diff --git a/.github/workflows/backport.yml b/.github/workflows/backport.yml index 56fef507..f4ddaa10 100644 --- a/.github/workflows/backport.yml +++ b/.github/workflows/backport.yml @@ -9,6 +9,17 @@ on: jobs: backport: runs-on: ubuntu-latest + # Only react to merged PRs for security reasons. + # See https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#pull_request_target. + if: > + github.event.pull_request.merged + && ( + github.event.action == 'closed' + || ( + github.event.action == 'labeled' + && contains(github.event.label.name, 'backport') + ) + ) permissions: contents: write pull-requests: write @@ -26,6 +37,6 @@ jobs: uses: VachaShah/backport@v2.2.0 with: github_token: ${{ steps.github_app_token.outputs.token }} - branch_name: backport/backport-${{ github.event.number }} + head_template: backport/backport-<%= number %>-to-<%= base %> labels_template: "<%= JSON.stringify([...labels, 'autocut']) %>" failure_labels: "failed backport" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3183a959..8edd403c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,10 +35,9 @@ jobs: needs: Get-CI-Image-Tag strategy: matrix: - java: - - 11 - - 17 - - 21 + java: [11, 17, 21] + env: + ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true name: Build and Test skills plugin on Linux runs-on: ubuntu-latest container: @@ -71,7 +70,7 @@ jobs: build-MacOS: strategy: matrix: - java: [ 11, 17 ] + java: [11, 17, 21] name: Build and Test skills Plugin on MacOS needs: Get-CI-Image-Tag @@ -90,15 +89,12 @@ jobs: export FC=/usr/local/Cellar/gcc/12.2.0/bin/gfortran - name: Run build run: | - ./gradlew build + ./gradlew build -Dos.arch=x86_64 build-windows: strategy: matrix: - java: - - 11 - - 17 - - 21 + java: [11, 17, 21] name: Build and Test skills plugin on Windows needs: Get-CI-Image-Tag runs-on: windows-latest diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml new file mode 100644 index 00000000..f5abac63 --- /dev/null +++ b/.github/workflows/labeler.yml @@ -0,0 +1,26 @@ +name: "Pull Request Labeler" +on: + pull_request_target: + branches: + - main + types: + - opened + +jobs: + label: + runs-on: ubuntu-latest + permissions: + contents: write + pull-requests: write + steps: + - name: GitHub App token + id: github_app_token + uses: tibdex/github-app-token@v2.1.0 + with: + app_id: ${{ secrets.APP_ID }} + private_key: ${{ secrets.APP_PRIVATE_KEY }} + installation_id: 22958780 + - name: Label + uses: actions/labeler@v5 + with: + repo-token: ${{ steps.github_app_token.outputs.token }} diff --git a/.github/workflows/maven-publish.yml b/.github/workflows/maven-publish.yml index 9786fa52..d07e5b31 100644 --- a/.github/workflows/maven-publish.yml +++ b/.github/workflows/maven-publish.yml @@ -11,7 +11,7 @@ jobs: build-and-publish-snapshots: strategy: fail-fast: false - if: github.repository == 'opensearch-project/agent-tools' + if: github.repository == 'opensearch-project/skills' runs-on: ubuntu-latest permissions: @@ -22,9 +22,9 @@ jobs: - uses: actions/setup-java@v3 with: distribution: temurin # Temurin is a distribution of adoptium - java-version: 17 + java-version: 21 - uses: actions/checkout@v3 - - uses: aws-actions/configure-aws-credentials@v1 + - uses: aws-actions/configure-aws-credentials@v4 with: role-to-assume: ${{ secrets.PUBLISH_SNAPSHOTS_ROLE }} aws-region: us-east-1 @@ -34,4 +34,4 @@ jobs: export SONATYPE_PASSWORD=$(aws secretsmanager get-secret-value --secret-id maven-snapshots-password --query SecretString --output 
text) echo "::add-mask::$SONATYPE_USERNAME" echo "::add-mask::$SONATYPE_PASSWORD" - ./gradlew publishShadowPublicationToSnapshotsRepository + ./gradlew publishPluginZipPublicationToSnapshotsRepository diff --git a/.github/workflows/test_security.yml b/.github/workflows/test_security.yml new file mode 100644 index 00000000..43274ed5 --- /dev/null +++ b/.github/workflows/test_security.yml @@ -0,0 +1,44 @@ +name: Run Security tests +on: + push: + branches-ignore: + - 'whitesource-remediate/**' + - 'backport/**' + pull_request: + types: [opened, synchronize, reopened] + +jobs: + Get-CI-Image-Tag: + uses: opensearch-project/opensearch-build/.github/workflows/get-ci-image-tag.yml@main + with: + product: opensearch + + integ-test-with-security-linux: + strategy: + matrix: + java: [11, 17, 21] + env: + ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true + name: Run Security Integration Tests on Linux + runs-on: ubuntu-latest + needs: Get-CI-Image-Tag + container: + # using the same image which is used by opensearch-build team to build the OpenSearch Distribution + # this image tag is subject to change as more dependencies and updates will arrive over time + image: ${{ needs.Get-CI-Image-Tag.outputs.ci-image-version-linux }} + # need to switch to root so that github actions can install runner binary on container without permission issues. + options: --user root + + steps: + - name: Checkout Skills + uses: actions/checkout@v3 + - name: Setup Java ${{ matrix.java }} + uses: actions/setup-java@v3 + with: + distribution: 'temurin' + java-version: ${{ matrix.java }} + - name: Run tests + # switching the user, as OpenSearch cluster can only be started as root/Administrator on linux-deb/linux-rpm/windows-zip. + run: | + chown -R 1000:1000 `pwd` + su `id -un 1000` -c "whoami && java -version && ./gradlew integTest -Dsecurity.enabled=true" \ No newline at end of file diff --git a/build.gradle b/build.gradle index 55aa878a..f163c41d 100644 --- a/build.gradle +++ b/build.gradle @@ -3,19 +3,36 @@ * SPDX-License-Identifier: Apache-2.0 */ +import org.opensearch.gradle.test.RestIntegTestTask +import java.util.concurrent.Callable +import java.nio.file.Paths +import com.github.jengelman.gradle.plugins.shadow.ShadowPlugin + buildscript { ext { opensearch_group = "org.opensearch" - opensearch_version = System.getProperty("opensearch.version", "2.11.0-SNAPSHOT") - isSnapshot = "true" == System.getProperty("build.snapshot", "true") + opensearch_version = System.getProperty("opensearch.version", "2.17.1-SNAPSHOT") buildVersionQualifier = System.getProperty("build.version_qualifier", "") + isSnapshot = "true" == System.getProperty("build.snapshot", "true") + version_tokens = opensearch_version.tokenize('-') + opensearch_build = version_tokens[0] + '.0' + plugin_no_snapshot = opensearch_build + if (buildVersionQualifier) { + opensearch_build += "-${buildVersionQualifier}" + plugin_no_snapshot += "-${buildVersionQualifier}" + } + if (isSnapshot) { + opensearch_build += "-SNAPSHOT" + } + opensearch_no_snapshot = opensearch_build.replace("-SNAPSHOT","") + kotlin_version = System.getProperty("kotlin.version", "1.8.21") } repositories { mavenLocal() - mavenCentral() - maven { url "https://plugins.gradle.org/m2/" } maven { url "https://aws.oss.sonatype.org/content/repositories/snapshots" } + maven { url "https://plugins.gradle.org/m2/" } + mavenCentral() } dependencies { @@ -25,26 +42,25 @@ buildscript { plugins { id 'java-library' - id 'com.diffplug.spotless' version '6.22.0' - id "io.freefair.lombok" version "8.0.1" + id 
'com.diffplug.spotless' version '6.25.0' + id "io.freefair.lombok" version "8.10" + id "de.undercouch.download" version "5.6.0" +} + +lombok { + version = "1.18.34" } repositories { mavenLocal() - mavenCentral() - maven { url "https://plugins.gradle.org/m2/" } maven { url "https://aws.oss.sonatype.org/content/repositories/snapshots" } + maven { url "https://plugins.gradle.org/m2/" } + mavenCentral() } allprojects { - group 'org.opensearch.plugin' - version = opensearch_version.tokenize('-')[0] + '.0' - if (buildVersionQualifier) { - version += "-${buildVersionQualifier}" - } - if (isSnapshot) { - version += "-SNAPSHOT" - } + group = opensearch_group + version = "${opensearch_build}" } targetCompatibility = JavaVersion.VERSION_11 @@ -61,31 +77,84 @@ apply plugin: 'opensearch.opensearchplugin' apply plugin: 'opensearch.testclusters' apply plugin: 'opensearch.pluginzip' +def sqlJarDirectory = "$buildDir/dependencies/opensearch-sql-plugin" +def jsJarDirectory = "$buildDir/dependencies/opensearch-job-scheduler" +def adJarDirectory = "$buildDir/dependencies/opensearch-anomaly-detection" configurations { zipArchive + secureIntegTestPluginArchive all { - resolutionStrategy.force "org.mockito:mockito-core:5.5.0" + resolutionStrategy { + force "org.mockito:mockito-core:${versions.mockito}" + force "com.google.guava:guava:33.2.1-jre" // CVE for 31.1 + force("org.eclipse.platform:org.eclipse.core.runtime:3.30.0") // CVE for < 3.29.0, forces JDK17 for spotless + } } } -def sqlJarDirectory = "$buildDir/dependencies/opensearch-sql-plugin" +task addJarsToClasspath(type: Copy) { + from(fileTree(dir: sqlJarDirectory)) { + include "opensearch-sql-${opensearch_build}.jar" + include "ppl-${opensearch_build}.jar" + include "protocol-${opensearch_build}.jar" + } + into("$buildDir/classes") + + from(fileTree(dir: jsJarDirectory)) { + include "opensearch-job-scheduler-${opensearch_build}.jar" + } + into("$buildDir/classes") + + from(fileTree(dir: adJarDirectory)) { + include "opensearch-anomaly-detection-${opensearch_build}.jar" + } + into("$buildDir/classes") +} dependencies { - compileOnly group: 'org.opensearch', name:'opensearch-ml-client', version: "${version}" + // 3P dependencies compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1' + compileOnly "org.apache.logging.log4j:log4j-slf4j-impl:2.23.1" + compileOnly group: 'org.json', name: 'json', version: '20240205' + compileOnly("com.google.guava:guava:33.2.1-jre") + compileOnly group: 'org.apache.commons', name: 'commons-lang3', version: '3.16.0' + compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.11.0' + + // Plugin dependencies + compileOnly group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}" + implementation fileTree(dir: jsJarDirectory, include: ["opensearch-job-scheduler-${opensearch_build}.jar"]) + implementation fileTree(dir: adJarDirectory, include: ["opensearch-anomaly-detection-${opensearch_build}.jar"]) + implementation fileTree(dir: sqlJarDirectory, include: ["opensearch-sql-${opensearch_build}.jar", "ppl-${opensearch_build}.jar", "protocol-${opensearch_build}.jar"]) + compileOnly "org.opensearch:common-utils:${opensearch_build}" + compileOnly "org.jetbrains.kotlin:kotlin-stdlib:${kotlin_version}" + compileOnly "org.opensearch:opensearch-job-scheduler-spi:${opensearch_build}" + + + // ZipArchive dependencies used for integration tests + zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}" + zipArchive group: 
'org.opensearch.plugin', name:'opensearch-job-scheduler', version: "${opensearch_build}" + zipArchive group: 'org.opensearch.plugin', name:'opensearch-anomaly-detection', version: "${opensearch_build}" + zipArchive group: 'org.opensearch.plugin', name:'opensearch-sql-plugin', version: "${opensearch_build}" + zipArchive group: 'org.opensearch.plugin', name:'opensearch-knn', version: "${opensearch_build}" + zipArchive group: 'org.opensearch.plugin', name:'neural-search', version: "${opensearch_build}" + zipArchive group: 'org.opensearch.plugin', name:'alerting', version: "${opensearch_build}" + secureIntegTestPluginArchive group: 'org.opensearch.plugin', name:'opensearch-security', version: "${opensearch_build}" + + // Test dependencies testImplementation "org.opensearch.test:framework:${opensearch_version}" - testImplementation "org.mockito:mockito-core:3.10.0" - testImplementation 'org.junit.jupiter:junit-jupiter-api:5.7.2' - testImplementation 'org.mockito:mockito-junit-jupiter:3.10.0' + testImplementation group: 'junit', name: 'junit', version: '4.13.2' + testImplementation group: 'org.json', name: 'json', version: '20240205' + testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.10.0' + testImplementation group: 'org.mockito', name: 'mockito-inline', version: '5.2.0' + testImplementation("net.bytebuddy:byte-buddy:1.14.9") + testImplementation("net.bytebuddy:byte-buddy-agent:1.14.7") + testImplementation 'org.junit.jupiter:junit-jupiter-api:5.10.1' + testImplementation 'org.mockito:mockito-junit-jupiter:5.10.0' testImplementation "com.nhaarman.mockitokotlin2:mockito-kotlin:2.2.0" - testImplementation "com.cronutils:cron-utils:9.1.6" - testImplementation "commons-validator:commons-validator:1.7" - testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.7.2' - compileOnly "org.apache.logging.log4j:log4j-slf4j-impl:2.19.0" - compileOnly group: 'org.json', name: 'json', version: '20230227' - zipArchive group: 'org.opensearch.plugin', name:'opensearch-sql-plugin', version: "${version}" - implementation fileTree(dir: sqlJarDirectory, include: ["opensearch-sql-${version}.jar", "ppl-${version}.jar", "protocol-${version}.jar"]) + testImplementation "com.cronutils:cron-utils:9.2.1" + testImplementation "commons-validator:commons-validator:1.8.0" + testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.10.1' } task extractSqlJar(type: Copy) { @@ -94,35 +163,67 @@ task extractSqlJar(type: Copy) { into sqlJarDirectory } -project.tasks.delombok.dependsOn(extractSqlJar) +task extractJsJar(type: Copy) { + mustRunAfter() + from(zipTree(configurations.zipArchive.find { it.name.startsWith("opensearch-job-scheduler")})) + into jsJarDirectory +} + +task extractAdJar(type: Copy) { + mustRunAfter() + from(zipTree(configurations.zipArchive.find { it.name.startsWith("opensearch-anomaly-detection")})) + into adJarDirectory +} + +tasks.addJarsToClasspath.dependsOn(extractSqlJar) +tasks.addJarsToClasspath.dependsOn(extractJsJar) +tasks.addJarsToClasspath.dependsOn(extractAdJar) +project.tasks.delombok.dependsOn(addJarsToClasspath) tasks.publishNebulaPublicationToMavenLocal.dependsOn ':generatePomFileForPluginZipPublication' tasks.validateNebulaPom.dependsOn ':generatePomFileForPluginZipPublication' dependencyLicenses.enabled = false loggerUsageCheck.enabled = false testingConventions.enabled = false +thirdPartyAudit.enabled = false +publishNebulaPublicationToMavenLocal.enabled = false test { - useJUnitPlatform() testLogging { exceptionFormat "full" events "skipped", "passed", "failed" // 
"started" showStandardStreams true } + include '**/*Tests.class' + systemProperty 'tests.security.manager', 'false' } -spotless { - java { - removeUnusedImports() - importOrder 'java', 'javax', 'org', 'com' - licenseHeaderFile 'spotless.license.java' +jacocoTestReport { + dependsOn test + reports { + html.required = true // human readable + xml.required = true // for coverlay + } +} - eclipse().configFile rootProject.file('.eclipseformat.xml') +spotless { + if (JavaVersion.current() >= JavaVersion.VERSION_17) { + // Spotless configuration for Java files + java { + removeUnusedImports() + importOrder 'java', 'javax', 'org', 'com' + licenseHeaderFile 'spotless.license.java' + eclipse().configFile rootProject.file('.eclipseformat.xml') + } + } else { + logger.lifecycle("Spotless plugin requires Java 17 or higher. Skipping Spotless tasks.") } } compileJava { dependsOn extractSqlJar + dependsOn extractJsJar + dependsOn extractAdJar dependsOn delombok options.compilerArgs.addAll(["-processor", 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor']) } @@ -131,36 +232,69 @@ compileTestJava { options.compilerArgs.addAll(["-processor", 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor']) } +forbiddenApisTest.ignoreFailures = true + opensearchplugin { - name 'agent-tools' - description 'OpenSearch Agent Tools' - classname 'org.opensearch.agent_tool.ToolPlugin' + name 'opensearch-skills' + description 'OpenSearch Skills' + classname 'org.opensearch.agent.ToolPlugin' extendedPlugins = ['opensearch-ml'] licenseFile rootProject.file("LICENSE.txt") noticeFile rootProject.file("NOTICE") } -publishing { - repositories { - maven { - name = 'staging' - url = "${rootProject.buildDir}/local-staging-repo" - } - maven { - name = "Snapshots" - url = "https://aws.oss.sonatype.org/content/repositories/snapshots" - credentials { - username "$System.env.SONATYPE_USERNAME" - password "$System.env.SONATYPE_PASSWORD" +def opensearch_tmp_dir = rootProject.file('build/private/opensearch_tmp').absoluteFile +opensearch_tmp_dir.mkdirs() + +ext { + projectSubstitutions = [:] + isSnapshot = "true" == System.getProperty("build.snapshot", "true") +} + +allprojects { + // Default to the apache license + project.ext.licenseName = 'The Apache Software License, Version 2.0' + project.ext.licenseUrl = 'http://www.apache.org/licenses/LICENSE-2.0.txt' + plugins.withType(ShadowPlugin).whenPluginAdded { + publishing { + repositories { + maven { + name = 'staging' + url = "${rootProject.buildDir}/local-staging-repo" + } + } + publications { + // add license information to generated poms + all { + pom { + name = "skills" + description = "Tools for Agent Framework" + } + pom.withXml { XmlProvider xml -> + Node node = xml.asNode() + node.appendNode('inceptionYear', '2021') + + Node license = node.appendNode('licenses').appendNode('license') + license.appendNode('name', project.licenseName) + license.appendNode('url', project.licenseUrl) + + Node developer = node.appendNode('developers').appendNode('developer') + developer.appendNode('name', 'OpenSearch') + developer.appendNode('url', 'https://github.com/opensearch-project/')skills + } + } } } } +} + +publishing { publications { pluginZip(MavenPublication) { publication -> pom { - name = "OpenSearch Agent Tools" - description = "OpenSearch Agent Tools" + name = "OpenSearch Skills" + description = "OpenSearch Skills" groupId = "org.opensearch.plugin" licenses { license { @@ -171,16 +305,259 @@ publishing { developers { developer { name = "OpenSearch" - url = 
"https://github.com/opensearch-project/agent-tools" + url = "https://github.com/opensearch-project/skills" } } } } } + repositories { + maven { + name = "Snapshots" + url = "https://aws.oss.sonatype.org/content/repositories/snapshots" + credentials { + username "$System.env.SONATYPE_USERNAME" + password "$System.env.SONATYPE_PASSWORD" + } + } + } gradle.startParameter.setShowStacktrace(ShowStacktrace.ALWAYS) gradle.startParameter.setLogLevel(LogLevel.DEBUG) } + +def _numNodes = findProperty('numNodes') as Integer ?: 1 + +// Set up integration tests +task integTest(type: RestIntegTestTask) { + description = "Run tests against a cluster" + testClassesDirs = sourceSets.test.output.classesDirs + classpath = sourceSets.test.runtimeClasspath +} +tasks.named("check").configure { dependsOn(integTest) } + +integTest { + + dependsOn "bundlePlugin" + systemProperty 'tests.security.manager', 'false' + systemProperty 'java.io.tmpdir', opensearch_tmp_dir.absolutePath + systemProperty('project.root', project.rootDir.absolutePath) + systemProperty "https", System.getProperty("https") + systemProperty "user", System.getProperty("user") + systemProperty "password", System.getProperty("password") + + systemProperty 'security.enabled', System.getProperty('security.enabled') + var is_https = System.getProperty("https") + var user = System.getProperty("user") + var password = System.getProperty("password") + + if (System.getProperty("security.enabled") != null) { + // If security is enabled, set is_https/user/password defaults + // admin password is permissable here since the security plugin is manually configured using the default internal_users.yml configuration + is_https = is_https == null ? "true" : is_https + user = user == null ? "admin" : user + password = password == null ? "admin" : password + System.setProperty("https", is_https) + System.setProperty("user", user) + System.setProperty("password", password) + } + + systemProperty("https", is_https) + systemProperty("user", user) + systemProperty("password", password) + + // Certain integ tests require system index manipulation to properly test. We exclude those + // in the security-enabled scenario since this action is prohibited by security plugin. + if (System.getProperty("https") != null && System.getProperty("https") == "true") { + filter { + excludeTestsMatching "org.opensearch.integTest.SearchAlertsToolIT" + excludeTestsMatching "org.opensearch.integTest.SearchAnomalyResultsToolIT" + } + } + + + // doFirst delays this block until execution time + doFirst { + // Tell the test JVM if the cluster JVM is running under a debugger so that tests can + // use longer timeouts for requests. 
+ def isDebuggingCluster = getDebug() || System.getProperty("test.debug") != null + systemProperty 'cluster.debug', isDebuggingCluster + // Set number of nodes system property to be used in tests + systemProperty 'cluster.number_of_nodes', "${_numNodes}" + // There seems to be an issue when running multi node run or integ tasks with unicast_hosts + // not being written, the waitForAllConditions ensures it's written + getClusters().forEach { cluster -> + cluster.waitForAllConditions() + } + } + + // The --debug-jvm command-line option makes the cluster debuggable; this makes the tests debuggable + if (System.getProperty("test.debug") != null) { + jvmArgs '-agentlib:jdwp=transport=dt_socket,server=y,suspend=y,address=*:5005' + } +} + +// Set up integration test clusters, installs all zipArchive dependencies and this plugin +testClusters.integTest { + testDistribution = "ARCHIVE" + + // Optionally install security + if (System.getProperty("security.enabled") != null) { + configurations.secureIntegTestPluginArchive.asFileTree.each { + if(it.name.contains("opensearch-security")){ + plugin(provider(new Callable() { + @Override + RegularFile call() throws Exception { + return new RegularFile() { + @Override + File getAsFile() { + return it + } + } + } + })) + } + } + + getNodes().forEach { node -> + var creds = node.getCredentials() + // admin password is permissable here since the security plugin is manually configured using the default internal_users.yml configuration + if (creds.isEmpty()) { + creds.add(Map.of('username', 'admin', 'password', 'admin')) + } else { + creds.get(0).putAll(Map.of('username', 'admin', 'password', 'admin')) + } + } + + // Config below including files are copied from security demo configuration + ['esnode.pem', 'esnode-key.pem', 'root-ca.pem','kirk.pem','kirk-key.pem'].forEach { file -> + File local = Paths.get(opensearch_tmp_dir.absolutePath, file).toFile() + download.run { + src "https://raw.githubusercontent.com/opensearch-project/security/main/bwc-test/src/test/resources/security/" + file + dest local + overwrite false + } + } + + // // Config below including files are copied from security demo configuration + extraConfigFile("esnode.pem", file("$opensearch_tmp_dir/esnode.pem")) + extraConfigFile("esnode-key.pem", file("$opensearch_tmp_dir/esnode-key.pem")) + extraConfigFile("root-ca.pem", file("$opensearch_tmp_dir/root-ca.pem")) + + // This configuration is copied from the security plugins demo install: + // https://github.com/opensearch-project/security/blob/2.11.1.0/tools/install_demo_configuration.sh#L365-L388 + setting("plugins.security.ssl.transport.pemcert_filepath", "esnode.pem") + setting("plugins.security.ssl.transport.pemkey_filepath", "esnode-key.pem") + setting("plugins.security.ssl.transport.pemtrustedcas_filepath", "root-ca.pem") + setting("plugins.security.ssl.transport.enforce_hostname_verification", "false") + setting("plugins.security.ssl.http.enabled", "true") + setting("plugins.security.ssl.http.pemcert_filepath", "esnode.pem") + setting("plugins.security.ssl.http.pemkey_filepath", "esnode-key.pem") + setting("plugins.security.ssl.http.pemtrustedcas_filepath", "root-ca.pem") + setting("plugins.security.allow_unsafe_democertificates", "true") + setting("plugins.security.allow_default_init_securityindex", "true") + setting("plugins.security.unsupported.inject_user.enabled", "true") + + setting("plugins.security.authcz.admin_dn", "\n- CN=kirk,OU=client,O=client,L=test, C=de") + setting('plugins.security.restapi.roles_enabled', '["all_access", 
"security_rest_api_access"]') + setting('plugins.security.system_indices.enabled', "true") + setting('plugins.security.system_indices.indices', '[' + + '".plugins-ml-config", ' + + '".plugins-ml-connector", ' + + '".plugins-ml-model-group", ' + + '".plugins-ml-model", ".plugins-ml-task", ' + + '".plugins-ml-conversation-meta", ' + + '".plugins-ml-conversation-interactions", ' + + '".opendistro-alerting-config", ' + + '".opendistro-alerting-alert*", ' + + '".opendistro-anomaly-results*", ' + + '".opendistro-anomaly-detector*", ' + + '".opendistro-anomaly-checkpoints", ' + + '".opendistro-anomaly-detection-state", ' + + '".opendistro-reports-*", ' + + '".opensearch-notifications-*", ' + + '".opensearch-notebooks", ' + + '".opensearch-observability", ' + + '".ql-datasources", ' + + '".opendistro-asynchronous-search-response*", ' + + '".replication-metadata-store", ' + + '".opensearch-knn-models", ' + + '".geospatial-ip2geo-data*", ' + + '".plugins-flow-framework-config", ' + + '".plugins-flow-framework-templates", ' + + '".plugins-flow-framework-state"' + + ']' + ) + setSecure(true) + } + + // Installs all registered zipArchive dependencies on integTest cluster nodes + configurations.zipArchive.asFileTree.each { + plugin(provider(new Callable(){ + @Override + RegularFile call() throws Exception { + return new RegularFile() { + @Override + File getAsFile() { + return it + } + } + } + })) + } + + // Install skills plugin on integTest cluster nodes + plugin(project.tasks.bundlePlugin.archiveFile) + + // Cluster shrink exception thrown if we try to set numberOfNodes to 1, so only apply if > 1 + if (_numNodes > 1) numberOfNodes = _numNodes + + // When running integration tests it doesn't forward the --debug-jvm to the cluster anymore + // i.e. we have to use a custom property to flag when we want to debug OpenSearch JVM + // since we also support multi node integration tests we increase debugPort per node + if (System.getProperty("opensearch.debug") != null) { + def debugPort = 5005 + nodes.forEach { node -> + node.jvmArgs("-agentlib:jdwp=transport=dt_socket,server=n,suspend=y,address=*:${debugPort}") + debugPort += 1 + } + } +} + +// Remote Integration Tests +task integTestRemote(type: RestIntegTestTask) { + testClassesDirs = sourceSets.test.output.classesDirs + classpath = sourceSets.test.runtimeClasspath + + systemProperty "https", System.getProperty("https") + systemProperty "user", System.getProperty("user") + systemProperty "password", System.getProperty("password") + + systemProperty 'cluster.number_of_nodes', "${_numNodes}" + + systemProperty 'tests.security.manager', 'false' + // Run tests with remote cluster only if rest case is defined + if (System.getProperty("tests.rest.cluster") != null) { + filter { + includeTestsMatching "org.opensearch.integTest.*IT" + } + } + + // Certain integ tests require system index manipulation to properly test. We exclude those + // in the security-enabled scenario since this action is prohibited by security plugin. 
+ if (System.getProperty("https") != null && System.getProperty("https") == "true") { + filter { + excludeTestsMatching "org.opensearch.integTest.SearchAlertsToolIT" + excludeTestsMatching "org.opensearch.integTest.SearchAnomalyResultsToolIT" + } + } +} + +// Automatically sets up the integration test cluster locally +run { + useCluster testClusters.integTest +} + // updateVersion: Task to auto increment to the next development iteration task updateVersion { onlyIf { System.getProperty('newVersion') } diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index 7f93135c..a4b76b95 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 3999f7f3..2b189974 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,7 +1,8 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-8.4-bin.zip +distributionSha256Sum=5b9c5eb3f9fc2c94abaea57d90bd78747ca117ddbbf96c859d3741181a12bf2a +distributionUrl=https\://services.gradle.org/distributions/gradle-8.10-bin.zip networkTimeout=10000 +validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists -distributionSha256Sum=3e1af3ae886920c3ac87f7a91f816c0c7c436f276a6eefdb3da152100fef72ae diff --git a/gradlew b/gradlew index 1aa94a42..f5feea6d 100755 --- a/gradlew +++ b/gradlew @@ -15,6 +15,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # +# SPDX-License-Identifier: Apache-2.0 +# ############################################################################## # @@ -55,7 +57,7 @@ # Darwin, MinGW, and NonStop. # # (3) This script is generated from the Groovy template -# https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# https://github.com/gradle/gradle/blob/HEAD/platforms/jvm/plugins-application/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt # within the Gradle project. # # You can find Gradle at https://github.com/gradle/gradle/. @@ -84,7 +86,8 @@ done # shellcheck disable=SC2034 APP_BASE_NAME=${0##*/} # Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) -APP_HOME=$( cd "${APP_HOME:-./}" > /dev/null && pwd -P ) || exit +APP_HOME=$( cd -P "${APP_HOME:-./}" > /dev/null && printf '%s +' "$PWD" ) || exit # Use the maximum available, or set MAX_FD != -1 to use that value. MAX_FD=maximum diff --git a/gradlew.bat b/gradlew.bat index 6689b85b..0ebb4c6c 100644 --- a/gradlew.bat +++ b/gradlew.bat @@ -13,6 +13,8 @@ @rem See the License for the specific language governing permissions and @rem limitations under the License. 
@rem +@rem SPDX-License-Identifier: Apache-2.0 +@rem @if "%DEBUG%"=="" @echo off @rem ########################################################################## diff --git a/release-notes/opensearch-skills.release-notes-2.12.0.0.md b/release-notes/opensearch-skills.release-notes-2.12.0.0.md new file mode 100644 index 00000000..4c763641 --- /dev/null +++ b/release-notes/opensearch-skills.release-notes-2.12.0.0.md @@ -0,0 +1,6 @@ +# 2024-02-08 Version 2.12.0.0 + +Compatible with OpenSearch 2.12.0 + +### Features +* Initial release of Skills \ No newline at end of file diff --git a/release-notes/opensearch-skills.release-notes-2.13.0.0.md b/release-notes/opensearch-skills.release-notes-2.13.0.0.md new file mode 100644 index 00000000..04a0b668 --- /dev/null +++ b/release-notes/opensearch-skills.release-notes-2.13.0.0.md @@ -0,0 +1,18 @@ +# 2024-03-20 Version 2.13.0.0 + +Compatible with OpenSearch 2.13.0 + +### Features +* Fix SearchAnomalyDetectorsTool indices param bug +* Fix detector state params in SearchAnomalyDetectorsTool +* Update ppl tool claude model prompts to use tags +* Add parameter validation for PPL tool + +### Dependencies +* Update mockito monorepo to v5.10.0 (#128) (#197) +* Update dependency org.apache.commons:commons-lang3 to v3.14.0 (#47) +* Update dependency org.apache.commons:commons-text to v1.11.0 (#62) +* Update plugin io.freefair.lombok to v8.6 (#245) (#249) +* Update plugin de.undercouch.download to v5.6.0 (#239) (#250) +* Update plugin com.diffplug.spotless to v6.25.0 (#127) (#252) +* Update dependency org.json:json to v20240205 (#246) (#251) diff --git a/release-notes/opensearch-skills.release-notes-2.14.0.0.md b/release-notes/opensearch-skills.release-notes-2.14.0.0.md new file mode 100644 index 00000000..5617b850 --- /dev/null +++ b/release-notes/opensearch-skills.release-notes-2.14.0.0.md @@ -0,0 +1,13 @@ +# 2024-04-29 Version 2.14.0.0 + +Compatible with OpenSearch 2.14.0 + +### Features +* Fix filter fields, adding geo point and date_nanos (#285) (#286) +* Change ad plugin jar dependency (#288) +* Remove logic about replace quota for finetuning model (#289) (#291) +* Move search index tool to ml-commons repo (#297) +* Move visualization tool to ml-commons (#296) (#298) + +### Dependencies +* Increment byte-buddy version to 1.14.9 (#288) diff --git a/release-notes/opensearch-skills.release-notes-2.15.0.0.md b/release-notes/opensearch-skills.release-notes-2.15.0.0.md new file mode 100644 index 00000000..f8a03999 --- /dev/null +++ b/release-notes/opensearch-skills.release-notes-2.15.0.0.md @@ -0,0 +1,6 @@ +# 2024-06-11 Version 2.15.0.0 + +Compatible with OpenSearch 2.15.0 + +### Maintenance +Increment version to 2.15.0.0. 
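A note on the version numbers in these release-notes files: they follow the four-digit plugin versioning scheme that the buildscript ext block added to build.gradle in this patch computes. The plugin version is the OpenSearch core version with a fourth component appended (OpenSearch 2.15.0 maps to plugin 2.15.0.0), with any build qualifier and the -SNAPSHOT suffix re-applied after that fourth digit. A minimal Java sketch of that derivation (the class and method names here are illustrative, not part of the patch):

// Sketch of the build.gradle version math: take the core version's first
// dash-separated token, append a ".0" fourth component, then re-apply the
// optional qualifier and -SNAPSHOT suffix in that order.
public final class PluginVersion {
    static String derive(String opensearchVersion, String qualifier, boolean snapshot) {
        String build = opensearchVersion.split("-")[0] + ".0"; // "2.17.1-SNAPSHOT" -> "2.17.1.0"
        if (qualifier != null && !qualifier.isEmpty()) {
            build += "-" + qualifier; // e.g. "2.17.1.0-alpha1"
        }
        return snapshot ? build + "-SNAPSHOT" : build;
    }

    public static void main(String[] args) {
        System.out.println(derive("2.15.0", "", false));           // prints 2.15.0.0
        System.out.println(derive("2.17.1-SNAPSHOT", "", true));   // prints 2.17.1.0-SNAPSHOT
    }
}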
diff --git a/release-notes/opensearch-skills.release-notes-2.16.0.0.md b/release-notes/opensearch-skills.release-notes-2.16.0.0.md new file mode 100644 index 00000000..2a474746 --- /dev/null +++ b/release-notes/opensearch-skills.release-notes-2.16.0.0.md @@ -0,0 +1,8 @@ +# 2024-07-29 Version 2.16.0.0 + +Compatible with OpenSearch 2.16.0 + +### Features +* support nested query in neural sparse tool, vectorDB tool and RAG tool ([#350](https://github.com/opensearch-project/skills/pull/350)) +* Add cluster setting to control ppl execution ([#344](https://github.com/opensearch-project/skills/pull/344)) +* Add CreateAnomalyDetectorTool ([#348](https://github.com/opensearch-project/skills/pull/348)) diff --git a/release-notes/opensearch-skills.release-notes-2.17.0.0.md b/release-notes/opensearch-skills.release-notes-2.17.0.0.md new file mode 100644 index 00000000..46199f69 --- /dev/null +++ b/release-notes/opensearch-skills.release-notes-2.17.0.0.md @@ -0,0 +1,12 @@ +# 2024-09-07 Version 2.17.0.0 + +Compatible with OpenSearch 2.17.0 + +### Maintenance +update dependency org.apache.logging.log4j:log4j-slf4j-impl to v2.23.1 ([#256](https://github.com/opensearch-project/skills/pull/256)) +update dependency com.google.guava:guava to v33.2.1-jre ([#258](https://github.com/opensearch-project/skills/pull/258)) +Upgrade apache common lang version to 3.16 ([#371](https://github.com/opensearch-project/skills/pull/371)) +update dependency gradle to v8.10 ([#389](https://github.com/opensearch-project/skills/pull/389)) +update plugin io.freefair.lombok to v8.10 ([#393](https://github.com/opensearch-project/skills/pull/393)) + + diff --git a/scripts/build.sh b/scripts/build.sh index e0495d4a..490d93fe 100755 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -64,8 +64,13 @@ fi [[ "$SNAPSHOT" == "true" ]] && VERSION=$VERSION-SNAPSHOT [ -z "$OUTPUT" ] && OUTPUT=artifacts -./gradlew build -x test -Dopensearch.version=$VERSION -Dbuild.snapshot=$SNAPSHOT -Dbuild.version_qualifier=$QUALIFIER -./gradlew publishShadowPublicationToMavenLocal -Dopensearch.version=$VERSION -Dbuild.snapshot=$SNAPSHOT -Dbuild.version_qualifier=$QUALIFIER -./gradlew publishShadowPublicationToStagingRepository -Dopensearch.version=$VERSION -Dbuild.snapshot=$SNAPSHOT -Dbuild.version_qualifier=$QUALIFIER +./gradlew assemble -x test -Dopensearch.version=$VERSION -Dbuild.snapshot=$SNAPSHOT -Dbuild.version_qualifier=$QUALIFIER +./gradlew publishToMavenLocal -Dopensearch.version=$VERSION -Dbuild.snapshot=$SNAPSHOT -Dbuild.version_qualifier=$QUALIFIER +./gradlew publishPluginZipPublicationToZipStagingRepository -Dopensearch.version=$VERSION -Dbuild.snapshot=$SNAPSHOT -Dbuild.version_qualifier=$QUALIFIER mkdir -p $OUTPUT/maven/org/opensearch cp -r ./build/local-staging-repo/org/opensearch/. $OUTPUT/maven/org/opensearch + +mkdir -p $OUTPUT/plugins +zipPath=$(find . 
-path \*build/distributions/*.zip) +distributions="$(dirname "${zipPath}")" +cp ${distributions}/*.zip ./$OUTPUT/plugins diff --git a/settings.gradle b/settings.gradle index 26eac06a..bea9b83a 100644 --- a/settings.gradle +++ b/settings.gradle @@ -3,4 +3,4 @@ * SPDX-License-Identifier: Apache-2.0 */ -rootProject.name = 'agent-tools' +rootProject.name = 'opensearch-skills' diff --git a/src/main/java/org/opensearch/agent/ToolPlugin.java b/src/main/java/org/opensearch/agent/ToolPlugin.java index eba2f6a1..ac0aa484 100644 --- a/src/main/java/org/opensearch/agent/ToolPlugin.java +++ b/src/main/java/org/opensearch/agent/ToolPlugin.java @@ -10,7 +10,15 @@ import java.util.List; import java.util.function.Supplier; +import org.opensearch.agent.tools.CreateAnomalyDetectorTool; +import org.opensearch.agent.tools.NeuralSparseSearchTool; import org.opensearch.agent.tools.PPLTool; +import org.opensearch.agent.tools.RAGTool; +import org.opensearch.agent.tools.SearchAlertsTool; +import org.opensearch.agent.tools.SearchAnomalyDetectorsTool; +import org.opensearch.agent.tools.SearchAnomalyResultsTool; +import org.opensearch.agent.tools.SearchMonitorsTool; +import org.opensearch.agent.tools.VectorDBTool; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; @@ -52,13 +60,32 @@ public Collection<Object> createComponents( this.client = client; this.clusterService = clusterService; this.xContentRegistry = xContentRegistry; - PPLTool.Factory.getInstance().init(client); + NeuralSparseSearchTool.Factory.getInstance().init(client, xContentRegistry); + VectorDBTool.Factory.getInstance().init(client, xContentRegistry); + RAGTool.Factory.getInstance().init(client, xContentRegistry); + SearchAlertsTool.Factory.getInstance().init(client); + SearchAnomalyDetectorsTool.Factory.getInstance().init(client, namedWriteableRegistry); + SearchAnomalyResultsTool.Factory.getInstance().init(client, namedWriteableRegistry); + SearchMonitorsTool.Factory.getInstance().init(client); + CreateAnomalyDetectorTool.Factory.getInstance().init(client); return Collections.emptyList(); } @Override public List<Tool.Factory<? extends Tool>> getToolFactories() { - return List.of(PPLTool.Factory.getInstance()); + return List + .of( + PPLTool.Factory.getInstance(), + NeuralSparseSearchTool.Factory.getInstance(), + VectorDBTool.Factory.getInstance(), + RAGTool.Factory.getInstance(), + SearchAlertsTool.Factory.getInstance(), + SearchAnomalyDetectorsTool.Factory.getInstance(), + SearchAnomalyResultsTool.Factory.getInstance(), + SearchMonitorsTool.Factory.getInstance(), + CreateAnomalyDetectorTool.Factory.getInstance() + ); } + } diff --git a/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java b/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java new file mode 100644 index 00000000..f01dde7e --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/AbstractRetrieverTool.java @@ -0,0 +1,148 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedExceptionAction; +import java.util.HashMap; +import java.util.Map; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import 
org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +/** + * Abstract tool supports search paradigms in neural-search plugin. + */ +@Log4j2 +@Getter +@Setter +public abstract class AbstractRetrieverTool implements Tool { + public static final String DEFAULT_DESCRIPTION = "Use this tool to search data in OpenSearch index."; + public static final String INPUT_FIELD = "input"; + public static final String INDEX_FIELD = "index"; + public static final String SOURCE_FIELD = "source_field"; + public static final String DOC_SIZE_FIELD = "doc_size"; + public static final int DEFAULT_DOC_SIZE = 2; + + protected String description = DEFAULT_DESCRIPTION; + protected Client client; + protected NamedXContentRegistry xContentRegistry; + protected String index; + protected String[] sourceFields; + protected Integer docSize; + protected String version; + + protected AbstractRetrieverTool( + Client client, + NamedXContentRegistry xContentRegistry, + String index, + String[] sourceFields, + Integer docSize + ) { + this.client = client; + this.xContentRegistry = xContentRegistry; + this.index = index; + this.sourceFields = sourceFields; + this.docSize = docSize == null ? DEFAULT_DOC_SIZE : docSize; + } + + protected abstract String getQueryBody(String queryText); + + private static Map<String, Object> processResponse(SearchHit hit) { + Map<String, Object> docContent = new HashMap<>(); + docContent.put("_index", hit.getIndex()); + docContent.put("_id", hit.getId()); + docContent.put("_score", hit.getScore()); + docContent.put("_source", hit.getSourceAsMap()); + return docContent; + } + + private SearchRequest buildSearchRequest(Map<String, String> parameters) throws IOException { + String question = parameters.get(INPUT_FIELD); + if (StringUtils.isBlank(question)) { + throw new IllegalArgumentException("[" + INPUT_FIELD + "] is null or empty, can not process it."); + } + + String query = getQueryBody(question); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query); + searchSourceBuilder.parseXContent(queryParser); + searchSourceBuilder.fetchSource(sourceFields, null); + searchSourceBuilder.size(docSize); + SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index); + return searchRequest; + } + + @Override + public <T> void run(Map<String, String> parameters, ActionListener<T> listener) { + SearchRequest searchRequest; + try { + searchRequest = buildSearchRequest(parameters); + } catch (Exception e) { + log.error("Failed to build search request.", e); + listener.onFailure(e); + return; + } + + ActionListener<SearchResponse> actionListener = ActionListener.wrap(r -> { + SearchHit[] hits = r.getHits().getHits(); + + if (hits != null && hits.length > 0) { + StringBuilder contextBuilder = new StringBuilder(); + for (SearchHit hit : hits) { + Map<String, Object> docContent = processResponse(hit); + String docContentInString = AccessController + .doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(docContent)); + contextBuilder.append(docContentInString).append("\n"); + 
listener.onResponse((T) contextBuilder.toString()); + } else { + listener.onResponse((T) "Can not get any match from search result."); + } + }, e -> { + log.error("Failed to search index.", e); + listener.onFailure(e); + }); + client.search(searchRequest, actionListener); + } + + @Override + public boolean validate(Map<String, String> parameters) { + return parameters != null && parameters.size() > 0 && !StringUtils.isBlank(parameters.get("input")); + } + + protected static abstract class Factory<T extends Tool> implements Tool.Factory<T> { + protected Client client; + protected NamedXContentRegistry xContentRegistry; + + public void init(Client client, NamedXContentRegistry xContentRegistry) { + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + } +} diff --git a/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java new file mode 100644 index 00000000..bd018698 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java @@ -0,0 +1,453 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.security.AccessController; +import java.security.PrivilegedExceptionAction; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.StringJoiner; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import org.apache.commons.text.StringSubstitutor; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.admin.indices.mapping.get.GetMappingsRequest; +import org.opensearch.agent.tools.utils.ToolHelper; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; + +import com.google.common.collect.ImmutableMap; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +/** + * A tool used to help creating anomaly detector, the only one input parameter is the index name, this tool will get the mappings of the index + * in flight and let LLM give the suggested category field, aggregation field and correspond aggregation method which are required for the create + * anomaly detector API, the output of this tool is like: + *{ + * "index": "opensearch_dashboards_sample_data_ecommerce", + * "categoryField": "geoip.country_iso_code", + * "aggregationField": 
"total_quantity,total_unique_products,taxful_total_price", + * "aggregationMethod": "sum,count,sum", + * "dateFields": "customer_birth_date,order_date,products.created_on" + * } + */ +@Log4j2 +@Setter +@Getter +@ToolAnnotation(CreateAnomalyDetectorTool.TYPE) +public class CreateAnomalyDetectorTool implements Tool { + // the type of this tool + public static final String TYPE = "CreateAnomalyDetectorTool"; + + // the default description of this tool + private static final String DEFAULT_DESCRIPTION = + "This is a tool used to help creating anomaly detector. It takes a required argument which is the name of the index, extract the index mappings and let the LLM to give the suggested aggregation field, aggregation method, category field and the date field which are required to create an anomaly detector."; + // the regex used to extract the key information from the response of LLM + private static final String EXTRACT_INFORMATION_REGEX = + "(?s).*\\{category_field=([^|]*)\\|aggregation_field=([^|]*)\\|aggregation_method=([^}]*)}.*"; + // valid field types which support aggregation + private static final Set VALID_FIELD_TYPES = Set + .of( + "keyword", + "constant_keyword", + "wildcard", + "long", + "integer", + "short", + "byte", + "double", + "float", + "half_float", + "scaled_float", + "unsigned_long", + "ip" + ); + // the index name key in the output + private static final String OUTPUT_KEY_INDEX = "index"; + // the category field key in the output + private static final String OUTPUT_KEY_CATEGORY_FIELD = "categoryField"; + // the aggregation field key in the output + private static final String OUTPUT_KEY_AGGREGATION_FIELD = "aggregationField"; + // the aggregation method name key in the output + private static final String OUTPUT_KEY_AGGREGATION_METHOD = "aggregationMethod"; + // the date fields key in the output + private static final String OUTPUT_KEY_DATE_FIELDS = "dateFields"; + // the default prompt dictionary, includes claude and openai + private static final Map DEFAULT_PROMPT_DICT = loadDefaultPromptFromFile(); + // the name of this tool + @Setter + @Getter + private String name = TYPE; + // the description of this tool + @Getter + @Setter + private String description = DEFAULT_DESCRIPTION; + + // the version of this tool + @Getter + private String version; + + // the OpenSearch transport client + private Client client; + // the mode id of LLM + @Getter + private String modelId; + // LLM model type, CLAUDE or OPENAI + @Getter + private ModelType modelType; + // the default prompt for creating anomaly detector + private String contextPrompt; + + enum ModelType { + CLAUDE, + OPENAI; + + public static ModelType from(String value) { + return valueOf(value.toUpperCase(Locale.ROOT)); + } + + } + + /** + * + * @param client the OpenSearch transport client + * @param modelId the model ID of LLM + */ + public CreateAnomalyDetectorTool(Client client, String modelId, String modelType) { + this.client = client; + this.modelId = modelId; + if (!ModelType.OPENAI.toString().equalsIgnoreCase(modelType) && !ModelType.CLAUDE.toString().equalsIgnoreCase(modelType)) { + throw new IllegalArgumentException("Unsupported model_type: " + modelType); + } + this.modelType = ModelType.from(modelType); + this.contextPrompt = DEFAULT_PROMPT_DICT.getOrDefault(this.modelType.toString(), ""); + } + + /** + * The main running method of this tool + * @param parameters the input parameters + * @param listener the action listener + * + */ + @Override + public void run(Map parameters, ActionListener listener) { + Map 
enrichedParameters = enrichParameters(parameters); + String indexName = enrichedParameters.get("index"); + if (Strings.isNullOrEmpty(indexName)) { + throw new IllegalArgumentException( + "Return this final answer to human directly and do not use other tools: 'Please provide index name'. Please try to directly send this message to human to ask for index name" + ); + } + if (indexName.startsWith(".")) { + throw new IllegalArgumentException( + "CreateAnomalyDetectionTool doesn't support searching indices starting with '.' since it could be system index, current searching index name: " + + indexName + ); + } + + GetMappingsRequest getMappingsRequest = new GetMappingsRequest().indices(indexName); + client.admin().indices().getMappings(getMappingsRequest, ActionListener.wrap(response -> { + Map<String, MappingMetadata> mappings = response.getMappings(); + if (mappings.size() == 0) { + throw new IllegalArgumentException("No mapping found for the index: " + indexName); + } + + MappingMetadata mappingMetadata; + // when the index name is wildcard pattern, we fetch the mappings of the first index + if (indexName.contains("*")) { + mappingMetadata = mappings.get((String) mappings.keySet().toArray()[0]); + } else { + mappingMetadata = mappings.get(indexName); + } + + Map<String, Object> mappingSource = (Map<String, Object>) mappingMetadata.getSourceAsMap().get("properties"); + if (Objects.isNull(mappingSource)) { + throw new IllegalArgumentException( + "The index " + indexName + " doesn't have mapping metadata, please add data to it or using another index." + ); + } + + // flatten all the fields in the mapping + Map<String, String> fieldsToType = new HashMap<>(); + ToolHelper.extractFieldNamesTypes(mappingSource, fieldsToType, "", true); + + // find all date type fields from the mapping + final Set<String> dateFields = findDateTypeFields(fieldsToType); + if (dateFields.isEmpty()) { + throw new IllegalArgumentException( + "The index " + indexName + " doesn't have date type fields, cannot create an anomaly detector for it." 
+ ); + } + StringJoiner dateFieldsJoiner = new StringJoiner(","); + dateFields.forEach(dateFieldsJoiner::add); + + // filter the mapping to improve the accuracy of the result + // only fields that support aggregation are kept in the mapping and sent to the LLM + Map<String, String> filteredMapping = fieldsToType + .entrySet() + .stream() + .filter(entry -> VALID_FIELD_TYPES.contains(entry.getValue())) + .collect(Collectors.toUnmodifiableMap(Map.Entry::getKey, Map.Entry::getValue)); + + // construct the prompt + String prompt = constructPrompt(filteredMapping, indexName); + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet + .builder() + .parameters(Collections.singletonMap("prompt", prompt)) + .build(); + ActionRequest request = new MLPredictionTaskRequest( + modelId, + MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), + null + ); + + client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(mlTaskResponse -> { + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlTaskResponse.getOutput(); + ModelTensors modelTensors = modelTensorOutput.getMlModelOutputs().get(0); + ModelTensor modelTensor = modelTensors.getMlModelTensors().get(0); + Map<String, String> dataAsMap = (Map<String, String>) modelTensor.getDataAsMap(); + if (dataAsMap == null) { + listener.onFailure(new IllegalStateException("Remote endpoint fails to inference.")); + return; + } + String finalResponse = dataAsMap.get("response"); + if (Strings.isNullOrEmpty(finalResponse)) { + listener.onFailure(new IllegalStateException("Remote endpoint fails to inference, no response found.")); + return; + } + + // use regex pattern to extract the suggested parameters for the create anomaly detector API + Pattern pattern = Pattern.compile(EXTRACT_INFORMATION_REGEX); + Matcher matcher = pattern.matcher(finalResponse); + if (!matcher.matches()) { + log + .error( + "The inference result from remote endpoint is not valid because the result: [" + + finalResponse + + "] cannot match the regex: " + + EXTRACT_INFORMATION_REGEX + ); + listener + .onFailure( + new IllegalStateException( + "The inference result from remote endpoint is not valid, cannot extract the key information from the result." + ) + ); + return; + } + + // remove double quotes or whitespace if exists + String categoryField = matcher.group(1).replaceAll("\"", "").strip(); + String aggregationField = matcher.group(2).replaceAll("\"", "").strip(); + String aggregationMethod = matcher.group(3).replaceAll("\"", "").strip(); + + Map<String, String> result = ImmutableMap + .of( + OUTPUT_KEY_INDEX, + indexName, + OUTPUT_KEY_CATEGORY_FIELD, + categoryField, + OUTPUT_KEY_AGGREGATION_FIELD, + aggregationField, + OUTPUT_KEY_AGGREGATION_METHOD, + aggregationMethod, + OUTPUT_KEY_DATE_FIELDS, + dateFieldsJoiner.toString() + ); + listener.onResponse((T) AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(result))); + }, e -> { + log.error("fail to predict model: " + e); + listener.onFailure(e); + })); + }, e -> { + log.error("failed to get mapping: " + e); + if (e.toString().contains("IndexNotFoundException")) { + listener + .onFailure( + new IllegalArgumentException( + "Return this final answer to human directly and do not use other tools: 'The index doesn't exist, please provide another index and retry'. 
Please try to directly send this message to human to ask for index name" + ) + ); + } else { + listener.onFailure(e); + } + })); + } + + /** + * Enrich the parameters by adding the parameters extracted from the chat + * @param parameters the original parameters + * @return the enriched parameters with parameters extracting from the chat + */ + private Map<String, String> enrichParameters(Map<String, String> parameters) { + Map<String, String> result = new HashMap<>(parameters); + try { + // input is a map + Map<String, String> chatParameters = gson.fromJson(parameters.get("input"), Map.class); + result.putAll(chatParameters); + } catch (Exception e) { + // input is a string + String indexName = parameters.getOrDefault("input", ""); + if (!indexName.isEmpty()) { + result.put("index", indexName); + } + } + return result; + } + + /** + * + * @param fieldsToType the flattened field-> field type mapping + * @return a list containing all the date type fields + */ + private Set<String> findDateTypeFields(final Map<String, String> fieldsToType) { + Set<String> result = new HashSet<>(); + for (Map.Entry<String, String> entry : fieldsToType.entrySet()) { + String value = entry.getValue(); + if (value.equals("date") || value.equals("date_nanos")) { + result.add(entry.getKey()); + } + } + return result; + } + + @SuppressWarnings("unchecked") + private static Map<String, String> loadDefaultPromptFromFile() { + try (InputStream inputStream = CreateAnomalyDetectorTool.class.getResourceAsStream("CreateAnomalyDetectorDefaultPrompt.json")) { + if (inputStream != null) { + return gson.fromJson(new String(inputStream.readAllBytes(), StandardCharsets.UTF_8), Map.class); + } + } catch (IOException e) { + log.error("Failed to load prompt from the file CreateAnomalyDetectorDefaultPrompt.json, error: ", e); + } + return new HashMap<>(); + } + + /** + * + * @param fieldsToType the flattened field-> field type mapping + * @param indexName the index name + * @return the prompt about creating anomaly detector + */ + private String constructPrompt(final Map<String, String> fieldsToType, final String indexName) { + StringJoiner tableInfoJoiner = new StringJoiner("\n"); + for (Map.Entry<String, String> entry : fieldsToType.entrySet()) { + tableInfoJoiner.add("- " + entry.getKey() + ": " + entry.getValue()); + } + + Map<String, String> indexInfo = ImmutableMap.of("indexName", indexName, "indexMapping", tableInfoJoiner.toString()); + StringSubstitutor substitutor = new StringSubstitutor(indexInfo, "${indexInfo.", "}"); + return substitutor.replace(contextPrompt); + } + + /** + * + * @param parameters the input parameters + * @return false if the input parameters is null or empty + */ + @Override + public boolean validate(Map<String, String> parameters) { + return parameters != null && parameters.size() != 0; + } + + /** + * + * @return the type of this tool + */ + @Override + public String getType() { + return TYPE; + } + + /** + * The tool factory + */ + public static class Factory implements Tool.Factory<CreateAnomalyDetectorTool> { + private Client client; + + private static CreateAnomalyDetectorTool.Factory INSTANCE; + + /** + * Create or return the singleton factory instance + */ + public static CreateAnomalyDetectorTool.Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (CreateAnomalyDetectorTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new CreateAnomalyDetectorTool.Factory(); + return INSTANCE; + } + } + + public void init(Client client) { + this.client = client; + } + + /** + * + * @param map the input parameters + * @return the instance of this tool + */ + @Override + public CreateAnomalyDetectorTool create(Map<String, Object> map) { + String modelId = (String) 
map.getOrDefault("model_id", ""); + if (modelId.isEmpty()) { + throw new IllegalArgumentException("model_id cannot be empty."); + } + String modelType = (String) map.getOrDefault("model_type", ModelType.CLAUDE.toString()); + if (!ModelType.OPENAI.toString().equalsIgnoreCase(modelType) && !ModelType.CLAUDE.toString().equalsIgnoreCase(modelType)) { + throw new IllegalArgumentException("Unsupported model_type: " + modelType); + } + return new CreateAnomalyDetectorTool(client, modelId, modelType); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public String getDefaultVersion() { + return null; + } + } +} diff --git a/src/main/java/org/opensearch/agent/tools/NeuralSparseSearchTool.java b/src/main/java/org/opensearch/agent/tools/NeuralSparseSearchTool.java new file mode 100644 index 00000000..60168603 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/NeuralSparseSearchTool.java @@ -0,0 +1,151 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; +import java.util.Map; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.client.Client; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +/** + * This tool supports neural_sparse search with sparse encoding models and rank_features field. + */ +@Log4j2 +@Getter +@Setter +@ToolAnnotation(NeuralSparseSearchTool.TYPE) +public class NeuralSparseSearchTool extends AbstractRetrieverTool { + public static final String TYPE = "NeuralSparseSearchTool"; + public static final String MODEL_ID_FIELD = "model_id"; + public static final String EMBEDDING_FIELD = "embedding_field"; + public static final String NESTED_PATH_FIELD = "nested_path"; + + private String name = TYPE; + private String modelId; + private String embeddingField; + private String nestedPath; + + @Builder + public NeuralSparseSearchTool( + Client client, + NamedXContentRegistry xContentRegistry, + String index, + String embeddingField, + String[] sourceFields, + Integer docSize, + String modelId, + String nestedPath + ) { + super(client, xContentRegistry, index, sourceFields, docSize); + this.modelId = modelId; + this.embeddingField = embeddingField; + this.nestedPath = nestedPath; + } + + @Override + protected String getQueryBody(String queryText) { + if (StringUtils.isBlank(embeddingField) || StringUtils.isBlank(modelId)) { + throw new IllegalArgumentException( + "Parameter [" + EMBEDDING_FIELD + "] and [" + MODEL_ID_FIELD + "] can not be null or empty." 
+ ); + } + + Map queryBody; + if (StringUtils.isBlank(nestedPath)) { + queryBody = Map + .of("query", Map.of("neural_sparse", Map.of(embeddingField, Map.of("query_text", queryText, "model_id", modelId)))); + } else { + queryBody = Map + .of( + "query", + Map + .of( + "nested", + Map + .of( + "path", + nestedPath, + "score_mode", + "max", + "query", + Map.of("neural_sparse", Map.of(embeddingField, Map.of("query_text", queryText, "model_id", modelId))) + ) + ) + ); + } + + try { + return AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(queryBody)); + } catch (PrivilegedActionException e) { + throw new RuntimeException(e); + } + } + + @Override + public String getType() { + return TYPE; + } + + public static class Factory extends AbstractRetrieverTool.Factory { + private static Factory INSTANCE; + + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (NeuralSparseSearchTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + @Override + public NeuralSparseSearchTool create(Map params) { + String index = (String) params.get(INDEX_FIELD); + String embeddingField = (String) params.get(EMBEDDING_FIELD); + String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class); + String modelId = (String) params.get(MODEL_ID_FIELD); + Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : DEFAULT_DOC_SIZE; + String nestedPath = (String) params.get(NESTED_PATH_FIELD); + return NeuralSparseSearchTool + .builder() + .client(client) + .xContentRegistry(xContentRegistry) + .index(index) + .embeddingField(embeddingField) + .sourceFields(sourceFields) + .modelId(modelId) + .docSize(docSize) + .nestedPath(nestedPath) + .build(); + } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public String getDefaultVersion() { + return null; + } + } +} diff --git a/src/main/java/org/opensearch/agent/tools/PPLTool.java b/src/main/java/org/opensearch/agent/tools/PPLTool.java index 052c6d3c..621426ac 100644 --- a/src/main/java/org/opensearch/agent/tools/PPLTool.java +++ b/src/main/java/org/opensearch/agent/tools/PPLTool.java @@ -5,32 +5,37 @@ package org.opensearch.agent.tools; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.io.UncheckedIOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; import java.security.AccessController; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; +import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; +import java.util.Set; import java.util.StringJoiner; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.math.NumberUtils; +import org.apache.commons.text.StringSubstitutor; import org.json.JSONObject; import org.opensearch.action.ActionRequest; import org.opensearch.action.admin.indices.mapping.get.GetMappingsRequest; -import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; +import org.opensearch.agent.tools.utils.ToolHelper; import 
org.opensearch.client.Client; import org.opensearch.cluster.metadata.MappingMetadata; import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ActionResponse; -import org.opensearch.core.common.io.stream.InputStreamStreamInput; -import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; @@ -40,7 +45,6 @@ import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; -import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; @@ -58,6 +62,8 @@ import lombok.extern.log4j.Log4j2; @Log4j2 +@Setter +@Getter @ToolAnnotation(PPLTool.TYPE) public class PPLTool implements Tool { @@ -66,7 +72,8 @@ public class PPLTool implements Tool { @Setter private Client client; - private static final String DEFAULT_DESCRIPTION = "Use this tool to generate PPL and execute."; + private static final String DEFAULT_DESCRIPTION = + "\"Use this tool when the user asks a question based on the data in the cluster, or to parse a user statement about which index to use in a conversation.\nAlso use this tool when the question only contains index information.\n1. If the user question contains both a question and an index name, the input parameters are {'question': UserQuestion, 'index': IndexName}.\n2. If the user question contains only a question, the input parameter is {'question': UserQuestion}.\n3. If the user question contains only an index name, find the original human input from the conversation history and formulate the parameter as {'question': UserQuestion, 'index': IndexName}\nThe index name should be exactly as stated in the user's input."; @Setter @Getter @@ -81,26 +88,120 @@ public class PPLTool implements Tool { private String contextPrompt; - private static Gson gson = new Gson(); + private Boolean execute; - public PPLTool(Client client, String modelId, String contextPrompt) { + private PPLModelType pplModelType; + + private String previousToolKey; + + private int head; + + private static Gson gson = org.opensearch.ml.common.utils.StringUtils.gson; + + private static Map<String, String> DEFAULT_PROMPT_DICT; + + private static Set<String> ALLOWED_FIELDS_TYPE; + + static { + ALLOWED_FIELDS_TYPE = new HashSet<>(); // from + // https://github.com/opensearch-project/sql/blob/2.x/docs/user/ppl/general/datatypes.rst#data-types-mapping + // and https://opensearch.org/docs/latest/field-types/supported-field-types/index/ + ALLOWED_FIELDS_TYPE.add("boolean"); + ALLOWED_FIELDS_TYPE.add("byte"); + ALLOWED_FIELDS_TYPE.add("short"); + ALLOWED_FIELDS_TYPE.add("integer"); + ALLOWED_FIELDS_TYPE.add("long"); + ALLOWED_FIELDS_TYPE.add("float"); + ALLOWED_FIELDS_TYPE.add("half_float"); + ALLOWED_FIELDS_TYPE.add("scaled_float"); + ALLOWED_FIELDS_TYPE.add("double"); + ALLOWED_FIELDS_TYPE.add("keyword"); + ALLOWED_FIELDS_TYPE.add("text"); + ALLOWED_FIELDS_TYPE.add("date"); + ALLOWED_FIELDS_TYPE.add("date_nanos"); + ALLOWED_FIELDS_TYPE.add("ip"); + ALLOWED_FIELDS_TYPE.add("binary"); + ALLOWED_FIELDS_TYPE.add("object"); + ALLOWED_FIELDS_TYPE.add("nested"); + ALLOWED_FIELDS_TYPE.add("geo_point"); + + DEFAULT_PROMPT_DICT = loadDefaultPromptDict(); + } + + public enum PPLModelType { + CLAUDE, + FINETUNE, + OPENAI; + + public static PPLModelType from(String value) { + if (value.isEmpty()) { + return PPLModelType.CLAUDE; + } + try { + return PPLModelType.valueOf(value.toUpperCase(Locale.ROOT)); + } catch (Exception e) { + log.error("Wrong PPL model type, should be CLAUDE, FINETUNE, or OPENAI"); + return PPLModelType.CLAUDE; + } + } + + } + + public PPLTool( + Client client, + String modelId, + String contextPrompt, + String pplModelType, + String previousToolKey, + int head, + boolean execute + ) { this.client = client; this.modelId = modelId; - this.contextPrompt = contextPrompt; + this.pplModelType = PPLModelType.from(pplModelType); + if (contextPrompt.isEmpty()) { + this.contextPrompt = DEFAULT_PROMPT_DICT.getOrDefault(this.pplModelType.toString(), ""); + } else { + this.contextPrompt = contextPrompt; + } + this.previousToolKey = previousToolKey; + this.head = head; + this.execute = execute; } + @SuppressWarnings("unchecked") @Override public <T> void run(Map<String, String> parameters, ActionListener<T> listener) { - String indexName = parameters.get("index"); + extractFromChatParameters(parameters); + String indexName = getIndexNameFromParameters(parameters); + if (StringUtils.isBlank(indexName)) { + throw new IllegalArgumentException( + "Return this final answer to human directly and do not use other tools: 'Please provide index name'. 
Please try to directly send this message to human to ask for index name" + ); + } String question = parameters.get("question"); - SearchRequest searchRequest = buildSearchRequest(indexName); + if (StringUtils.isBlank(indexName) || StringUtils.isBlank(question)) { + throw new IllegalArgumentException("Parameters index and question cannot be null or empty."); + } + if (indexName.startsWith(".")) { + throw new IllegalArgumentException( + "PPLTool doesn't support searching indices starting with '.' since they could be system indices, current searching index name: " + + indexName + ); + } + GetMappingsRequest getMappingsRequest = buildGetMappingRequest(indexName); - client.admin().indices().getMappings(getMappingsRequest, ActionListener.wrap(getMappingsResponse -> { + client.admin().indices().getMappings(getMappingsRequest, ActionListener.wrap(getMappingsResponse -> { Map<String, MappingMetadata> mappings = getMappingsResponse.getMappings(); - client.search(searchRequest, ActionListener.wrap(searchResponse -> { + if (mappings.isEmpty()) { + throw new IllegalArgumentException("No matching mapping with index name: " + indexName); + } + String firstIndexName = (String) mappings.keySet().toArray()[0]; + SearchRequest searchRequest = buildSearchRequest(firstIndexName); + client.search(searchRequest, ActionListener.wrap(searchResponse -> { SearchHit[] searchHits = searchResponse.getHits().getHits(); - String tableInfo = constructTableInfo(searchHits, mappings, indexName); - String prompt = constructPrompt(tableInfo, question, indexName); + String tableInfo = constructTableInfo(searchHits, mappings); + String prompt = constructPrompt(tableInfo, question.strip(), indexName); RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet .builder() .parameters(Collections.singletonMap("prompt", prompt)) @@ -109,12 +210,17 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener) modelId, MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build() ); - client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(mlTaskResponse -> { + client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(mlTaskResponse -> { ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlTaskResponse.getOutput(); ModelTensors modelTensors = modelTensorOutput.getMlModelOutputs().get(0); ModelTensor modelTensor = modelTensors.getMlModelTensors().get(0); Map<String, String> dataAsMap = (Map<String, String>) modelTensor.getDataAsMap(); - String ppl = dataAsMap.get("output"); + String ppl = parseOutput(dataAsMap.get("response"), indexName); + if (!this.execute) { + Map<String, String> ret = ImmutableMap.of("ppl", ppl); + listener.onResponse((T) AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(ret))); + return; + } JSONObject jsonContent = new JSONObject(ImmutableMap.of("query", ppl)); PPLQueryRequest pplQueryRequest = new PPLQueryRequest(ppl, jsonContent, null, "jdbc"); TransportPPLQueryRequest transportPPLQueryRequest = new TransportPPLQueryRequest(pplQueryRequest); @@ -122,7 +228,7 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener) .execute( PPLQueryAction.INSTANCE, transportPPLQueryRequest, - getPPLTransportActionListener(ActionListener.wrap(transportPPLQueryResponse -> { + getPPLTransportActionListener(ActionListener.wrap(transportPPLQueryResponse -> { String results = transportPPLQueryResponse.getResult(); Map<String, String> returnResults = ImmutableMap.of("ppl", ppl, "executionResult", results); listener @@ -138,18 +244,26 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener) ); // Execute output here 
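+ // A rough sketch of the two response shapes this tool returns, using a hypothetical index name: + // execute=true -> {"ppl": "source=sample_logs | stats count()", "executionResult": "..."} + // execute=false -> {"ppl": "source=sample_logs | stats count()"}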
}, e -> { - log.info("fail to predict model: " + e); + log.error(String.format(Locale.ROOT, "failed to predict model: %s with error: %s", modelId, e.getMessage()), e); listener.onFailure(e); })); }, e -> { - log.info("fail to search: " + e); + log.error(String.format(Locale.ROOT, "failed to search index: %s with error: %s", firstIndexName, e.getMessage()), e); listener.onFailure(e); - } - - )); + })); }, e -> { - log.info("fail to get mapping: " + e); - listener.onFailure(e); + log.error(String.format(Locale.ROOT, "failed to get mapping of index: %s with error: %s", indexName, e.getMessage()), e); + String errorMessage = e.getMessage(); + if (errorMessage.contains("no such index")) { + listener + .onFailure( + new IllegalArgumentException( + "Return this final answer to human directly and do not use other tools: 'Please provide index name'. Please try to directly send this message to human to ask for index name" + ) + ); + } else { + listener.onFailure(e); + } })); } @@ -165,10 +279,7 @@ public String getName() { @Override public boolean validate(Map<String, String> parameters) { - if (parameters == null || parameters.size() == 0) { - return false; - } - return true; + return parameters != null && !parameters.isEmpty(); } public static class Factory implements Tool.Factory<PPLTool> { @@ -195,21 +306,40 @@ public void init(Client client) { @Override public PPLTool create(Map<String, Object> map) { - return new PPLTool(client, (String) map.get("model_id"), (String) map.get("prompt")); + validatePPLToolParameters(map); + return new PPLTool( + client, + (String) map.get("model_id"), + (String) map.getOrDefault("prompt", ""), + (String) map.getOrDefault("model_type", ""), + (String) map.getOrDefault("previous_tool_name", ""), + NumberUtils.toInt((String) map.get("head"), -1), + Boolean.parseBoolean((String) map.getOrDefault("execute", "true")) + ); } @Override public String getDefaultDescription() { return DEFAULT_DESCRIPTION; } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public String getDefaultVersion() { + return null; + } + } private SearchRequest buildSearchRequest(String indexName) { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchSourceBuilder.size(1).query(new MatchAllQueryBuilder()); // client; - SearchRequest request = new SearchRequest(new String[] { indexName }, searchSourceBuilder); - return request; + return new SearchRequest(new String[] { indexName }, searchSourceBuilder); } private GetMappingsRequest buildGetMappingRequest(String indexName) { @@ -219,13 +349,41 @@ private GetMappingsRequest buildGetMappingRequest(String indexName) { return getMappingsRequest; } - private String constructTableInfo(SearchHit[] searchHits, Map<String, MappingMetadata> mappings, String indexName) - throws PrivilegedActionException { - MappingMetadata mappingMetadata = mappings.get(indexName); + private static void validatePPLToolParameters(Map<String, Object> map) { + if (StringUtils.isBlank((String) map.get("model_id"))) { + throw new IllegalArgumentException("PPL tool needs a non-blank model id."); + } + if (map.containsKey("execute") && Objects.nonNull(map.get("execute"))) { + String execute = map.get("execute").toString().toLowerCase(Locale.ROOT); + if (!execute.equals("true") && !execute.equals("false")) { + throw new IllegalArgumentException("PPL tool parameter execute must be true or false"); + } + + } + if (map.containsKey("head")) { + String head = map.get("head").toString(); + try { + int headInt = NumberUtils.createInteger(head); + } catch (Exception e) { + throw new IllegalArgumentException("PPL tool parameter head 
must be an integer."); + } + } + } + + private String constructTableInfo(SearchHit[] searchHits, Map<String, MappingMetadata> mappings) throws PrivilegedActionException { + String firstIndexName = (String) mappings.keySet().toArray()[0]; + MappingMetadata mappingMetadata = mappings.get(firstIndexName); Map<String, Object> mappingSource = (Map<String, Object>) mappingMetadata.getSourceAsMap().get("properties"); + if (Objects.isNull(mappingSource)) { + throw new IllegalArgumentException( + "The querying index doesn't have mapping metadata, please add data to it or use another index." + ); + } Map<String, String> fieldsToType = new HashMap<>(); - extractNamesTypes(mappingSource, fieldsToType, ""); + ToolHelper.extractFieldNamesTypes(mappingSource, fieldsToType, "", false); StringJoiner tableInfoJoiner = new StringJoiner("\n"); + List<String> sortedKeys = new ArrayList<>(fieldsToType.keySet()); + Collections.sort(sortedKeys); if (searchHits.length > 0) { SearchHit hit = searchHits[0]; @@ -236,48 +394,33 @@ private String constructTableInfo(SearchHit[] searchHits, Map<String, MappingMetadata> mappings, String indexName) - private static void extractNamesTypes(Map<String, Object> mappingSource, Map<String, String> fieldsToType, String prefix) { - if (prefix.length() > 0) { - prefix += "."; - } - - for (Map.Entry<String, Object> entry : mappingSource.entrySet()) { - String n = entry.getKey(); - Object v = entry.getValue(); - - if (v instanceof Map) { - Map<String, Object> vMap = (Map<String, Object>) v; - if (vMap.containsKey("type")) { - fieldsToType.put(prefix + n, (String) vMap.get("type")); - } else if (vMap.containsKey("properties")) { - extractNamesTypes((Map<String, Object>) vMap.get("properties"), fieldsToType, prefix + n); - } - } - } + Map<String, String> indexInfo = ImmutableMap.of("mappingInfo", tableInfo, "question", question, "indexName", indexName); + StringSubstitutor substitutor = new StringSubstitutor(indexInfo, "${indexInfo.", "}"); + return substitutor.replace(contextPrompt); } private static void extractSamples(Map<String, Object> sampleSource, Map<String, String> fieldsToSample, String prefix) throws PrivilegedActionException { - if (prefix.length() > 0) { + if (!prefix.isEmpty()) { prefix += "."; } @@ -297,22 +440,93 @@ private static void extractSamples(Map<String, Object> sampleSource, Map<String, String> fieldsToSample, String prefix) private <T extends ActionResponse> ActionListener<T> getPPLTransportActionListener(ActionListener<TransportPPLQueryResponse> listener) { - return ActionListener.wrap(r -> { listener.onResponse(fromActionResponse(r)); }, listener::onFailure); + return ActionListener.wrap(r -> { listener.onResponse(TransportPPLQueryResponse.fromActionResponse(r)); }, listener::onFailure); } - private static TransportPPLQueryResponse fromActionResponse(ActionResponse actionResponse) { - if (actionResponse instanceof TransportPPLQueryResponse) { - return (TransportPPLQueryResponse) actionResponse; + @SuppressWarnings("unchecked") + private void extractFromChatParameters(Map<String, String> parameters) { + if (parameters.containsKey("input")) { + String input = parameters.get("input"); + try { + Map<String, String> chatParameters = gson.fromJson(input, Map.class); + parameters.putAll(chatParameters); + } catch (Exception e) { + log.error(String.format(Locale.ROOT, "Failed to parse chat parameters, input is: %s, which is not a valid json", input), e); + } } + } + + private String parseOutput(String llmOutput, String indexName) { + String ppl; + Pattern pattern = Pattern.compile("<ppl>((.|[\\r\\n])+?)</ppl>"); // For ppl like <ppl>source=a \n | fields b</ppl> + Matcher matcher = pattern.matcher(llmOutput); + + if (matcher.find()) { + ppl = matcher.group(1).replaceAll("[\\r\\n]", "").replaceAll("ISNOTNULL", "isnotnull").trim(); + } else { // logic for only ppl returned + int sourceIndex = llmOutput.indexOf("source="); + int describeIndex = llmOutput.indexOf("describe "); + if (sourceIndex != -1) { + llmOutput = llmOutput.substring(sourceIndex); + + // 
Splitting the string at "|" + String[] lists = llmOutput.split("\\|"); + + // Modifying the first element + if (lists.length > 0) { + lists[0] = "source=" + indexName; + } + + // Joining the string back together + ppl = String.join("|", lists); + } else if (describeIndex != -1) { + llmOutput = llmOutput.substring(describeIndex); + String[] lists = llmOutput.split("\\|"); - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { - actionResponse.writeTo(osso); - try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { - return new TransportPPLQueryResponse(input); + // Modifying the first element + if (lists.length > 0) { + lists[0] = "describe " + indexName; + } + + // Joining the string back together + ppl = String.join("|", lists); + } else { + throw new IllegalArgumentException("The returned PPL: " + llmOutput + " has an invalid format"); } - } catch (IOException e) { - throw new UncheckedIOException("failed to parse ActionResponse into TransportPPLQueryResponse", e); } + if (this.pplModelType != PPLModelType.FINETUNE) { + ppl = ppl.replace("`", ""); + } + ppl = ppl.replaceAll("\\bSPAN\\(", "span("); + if (this.head > 0) { + String[] lists = llmOutput.split("\\|"); + String lastCommand = lists[lists.length - 1].strip(); + if (!lastCommand.toLowerCase(Locale.ROOT).startsWith("head")) // do not append when the query already ends with a head command, e.g. source=... | head 5 + { + ppl = ppl + " | head " + this.head; + } + } + return ppl; + } + private String getIndexNameFromParameters(Map<String, String> parameters) { + String indexName = parameters.getOrDefault("index", ""); + if (!StringUtils.isBlank(this.previousToolKey) && StringUtils.isBlank(indexName)) { + indexName = parameters.getOrDefault(this.previousToolKey + ".output", ""); // read index name from previous key + } + return indexName.trim(); + } + + @SuppressWarnings("unchecked") + private static Map<String, String> loadDefaultPromptDict() { + try (InputStream searchResponseIns = PPLTool.class.getResourceAsStream("PPLDefaultPrompt.json")) { + if (searchResponseIns != null) { + String defaultPromptContent = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8); + return gson.fromJson(defaultPromptContent, Map.class); + } + } catch (IOException e) { + log.error("Failed to load default prompt dict", e); + } + return new HashMap<>(); } } diff --git a/src/main/java/org/opensearch/agent/tools/RAGTool.java b/src/main/java/org/opensearch/agent/tools/RAGTool.java new file mode 100644 index 00000000..e8159839 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/RAGTool.java @@ -0,0 +1,273 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.apache.commons.lang3.StringEscapeUtils.escapeJson; +import static org.opensearch.agent.tools.AbstractRetrieverTool.*; +import static org.opensearch.ml.common.utils.StringUtils.gson; +import static org.opensearch.ml.common.utils.StringUtils.toJson; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.action.ActionRequest; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import 
org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; + +import com.google.gson.Gson; +import com.google.gson.JsonObject; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +/** + * This tool supports retrieving helpful information to optimize the output of the large language model to answer questions. + */ +@Log4j2 +@Setter +@Getter +@ToolAnnotation(RAGTool.TYPE) +public class RAGTool implements Tool { + public static final String TYPE = "RAGTool"; + public static String DEFAULT_DESCRIPTION = + "Use this tool to retrieve helpful information to optimize the output of the large language model to answer questions."; + public static final String INFERENCE_MODEL_ID_FIELD = "inference_model_id"; + public static final String EMBEDDING_MODEL_ID_FIELD = "embedding_model_id"; + public static final String INDEX_FIELD = "index"; + public static final String SOURCE_FIELD = "source_field"; + public static final String DOC_SIZE_FIELD = "doc_size"; + public static final String EMBEDDING_FIELD = "embedding_field"; + public static final String OUTPUT_FIELD = "output_field"; + public static final String QUERY_TYPE = "query_type"; + public static final String CONTENT_GENERATION_FIELD = "enable_content_generation"; + public static final String K_FIELD = "k"; + private final AbstractRetrieverTool queryTool; + private String name = TYPE; + private String description = DEFAULT_DESCRIPTION; + private Client client; + private String inferenceModelId; + private Boolean enableContentGeneration; + private NamedXContentRegistry xContentRegistry; + private String queryType; + @Setter + private Parser inputParser; + @Setter + private Parser outputParser; + + @Builder + public RAGTool( + Client client, + NamedXContentRegistry xContentRegistry, + String inferenceModelId, + Boolean enableContentGeneration, + AbstractRetrieverTool queryTool + ) { + this.client = client; + this.xContentRegistry = xContentRegistry; + this.inferenceModelId = inferenceModelId; + this.enableContentGeneration = enableContentGeneration; + this.queryTool = queryTool; + outputParser = new Parser() { + @Override + public Object parse(Object o) { + List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o; + return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); + } + }; + } + + public <T> void run(Map<String, String> parameters, ActionListener<T> listener) { + String input = null; + + if (!this.validate(parameters)) { + throw new IllegalArgumentException("[" + INPUT_FIELD + "] is null or empty, cannot process it."); + } + + try { + input = parameters.get(INPUT_FIELD); + } catch (Exception e) { + log.error("Failed to read question from " + INPUT_FIELD, e); + listener.onFailure(new IllegalArgumentException("Failed to read question from " + INPUT_FIELD)); + return; + } + + String embeddingInput = input; + ActionListener<T> actionListener = ActionListener.wrap(r -> { + T queryToolOutput; + if (!this.enableContentGeneration) { + // content generation is disabled: return the raw retrieval result and stop here + listener.onResponse(r); + return; + } + if (r.equals("Can not get any match from search result.")) { + queryToolOutput = (T) ""; + } else { + Gson gson = new Gson(); + String[] hits = r.toString().split("\n"); + + StringBuilder 
resultBuilder = new StringBuilder(); + for (String hit : hits) { + JsonObject jsonObject = gson.fromJson(hit, JsonObject.class); + String id = jsonObject.get("_id").getAsString(); + JsonObject source = jsonObject.getAsJsonObject("_source"); + + resultBuilder.append("_id: ").append(id).append("\n"); + resultBuilder.append("_source: ").append(source.toString()).append("\n"); + } + + queryToolOutput = (T) gson.toJson(resultBuilder.toString()); + } + + Map tmpParameters = new HashMap<>(); + tmpParameters.putAll(parameters); + + if (queryToolOutput instanceof String) { + tmpParameters.put(OUTPUT_FIELD, (String) queryToolOutput); + } else { + tmpParameters.put(OUTPUT_FIELD, escapeJson(toJson(queryToolOutput.toString()))); + } + + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(); + ActionRequest request = new MLPredictionTaskRequest(this.inferenceModelId, mlInput, null); + + client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(resp -> { + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) resp.getOutput(); + modelTensorOutput.getMlModelOutputs(); + if (outputParser == null) { + listener.onResponse((T) modelTensorOutput.getMlModelOutputs()); + } else { + listener.onResponse((T) outputParser.parse(modelTensorOutput.getMlModelOutputs())); + } + }, e -> { + log.error("Failed to run model " + this.inferenceModelId, e); + listener.onFailure(e); + })); + }, e -> { + log.error("Failed to search index.", e); + listener.onFailure(e); + }); + this.queryTool.run(Map.of(INPUT_FIELD, embeddingInput), actionListener); + } + + public String getType() { + return TYPE; + } + + @Override + public String getVersion() { + return null; + } + + public String getName() { + return this.name; + } + + public void setName(String s) { + this.name = s; + } + + public boolean validate(Map parameters) { + if (parameters == null || parameters.size() == 0) { + return false; + } + String question = parameters.get(INPUT_FIELD); + return question != null && !question.trim().isEmpty(); + } + + /** + * Factory class to create RAGTool + */ + public static class Factory implements Tool.Factory { + private Client client; + private NamedXContentRegistry xContentRegistry; + + private static Factory INSTANCE; + + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (RAGTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + public void init(Client client, NamedXContentRegistry xContentRegistry) { + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + @Override + public RAGTool create(Map params) { + String queryType = params.containsKey(QUERY_TYPE) ? (String) params.get(QUERY_TYPE) : "neural"; + String embeddingModelId = (String) params.get(EMBEDDING_MODEL_ID_FIELD); + Boolean enableContentGeneration = params.containsKey(CONTENT_GENERATION_FIELD) + ? Boolean.parseBoolean((String) params.get(CONTENT_GENERATION_FIELD)) + : true; + String inferenceModelId = enableContentGeneration ? 
(String) params.get(INFERENCE_MODEL_ID_FIELD) : ""; + switch (queryType) { + case "neural_sparse": + params.put(NeuralSparseSearchTool.MODEL_ID_FIELD, embeddingModelId); + NeuralSparseSearchTool neuralSparseSearchTool = NeuralSparseSearchTool.Factory.getInstance().create(params); + return RAGTool + .builder() + .client(client) + .xContentRegistry(xContentRegistry) + .inferenceModelId(inferenceModelId) + .enableContentGeneration(enableContentGeneration) + .queryTool(neuralSparseSearchTool) + .build(); + case "neural": + params.put(VectorDBTool.MODEL_ID_FIELD, embeddingModelId); + VectorDBTool vectorDBTool = VectorDBTool.Factory.getInstance().create(params); + return RAGTool + .builder() + .client(client) + .xContentRegistry(xContentRegistry) + .inferenceModelId(inferenceModelId) + .enableContentGeneration(enableContentGeneration) + .queryTool(vectorDBTool) + .build(); + default: + log.error("Failed to read queryType, please input neural_sparse or neural."); + throw new IllegalArgumentException("Failed to read queryType, please input neural_sparse or neural."); + } + + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public String getDefaultVersion() { + return null; + } + } +} diff --git a/src/main/java/org/opensearch/agent/tools/SearchAlertsTool.java b/src/main/java/org/opensearch/agent/tools/SearchAlertsTool.java new file mode 100644 index 00000000..7be36955 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/SearchAlertsTool.java @@ -0,0 +1,194 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.List; +import java.util.Map; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.client.Client; +import org.opensearch.client.node.NodeClient; +import org.opensearch.commons.alerting.AlertingPluginInterface; +import org.opensearch.commons.alerting.action.GetAlertsRequest; +import org.opensearch.commons.alerting.action.GetAlertsResponse; +import org.opensearch.commons.alerting.model.Alert; +import org.opensearch.commons.alerting.model.Table; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@ToolAnnotation(SearchAlertsTool.TYPE) +public class SearchAlertsTool implements Tool { + public static final String TYPE = "SearchAlertsTool"; + private static final String DEFAULT_DESCRIPTION = + "This is a tool that finds alerts. 
It takes 12 optional arguments named sortOrder which defines the order of the results (options are asc or desc, and default is asc), and sortString which defines how to sort the results (default is monitor_name.keyword), and size which defines the size of the request to be returned (default is 20), and startIndex which defines the paginated index to start from (default is 0), and searchString which defines the search string to use for searching a specific alert (default is an empty String), and severityLevel which defines the severity level to filter for as an integer (default is ALL), and alertState which defines the alert state to filter for (options are ALL, ACTIVE, ERROR, COMPLETED, or ACKNOWLEDGED, default is ALL), and monitorId which defines the associated monitor ID to filter for, and alertIndex which defines the alert index to search from (default is null), and monitorIds which defines the list of monitor IDs to filter for, and workflowIds which defines the list of workflow IDs to filter for (default is null), and alertIds which defines the list of alert IDs to filter for (default is null). The tool returns 2 values: a list of alerts (each containing the alert id, version, schema version, monitor ID, workflow ID, workflow name, monitor name, monitor version, monitor user, trigger ID, trigger name, finding IDs, related doc IDs, state, start time in epoch milliseconds, end time in epoch milliseconds, last notification time in epoch milliseconds, acknowledged time in epoch milliseconds, error message, error history, severity, action execution results, aggregation result bucket, execution ID, associated alert IDs), and the total number of alerts."; + + @Setter + @Getter + private String name = TYPE; + @Getter + @Setter + private String description = DEFAULT_DESCRIPTION; + @Getter + private String type; + @Getter + private String version; + + private Client client; + @Setter + private Parser inputParser; + @Setter + private Parser outputParser; + + public SearchAlertsTool(Client client) { + this.client = client; + + // probably keep this overridden output parser; need to ensure the output matches what's expected + outputParser = new Parser<>() { + @Override + public Object parse(Object o) { + @SuppressWarnings("unchecked") + List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o; + return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); + } + }; + } + + @Override + public <T> void run(Map<String, String> parameters, ActionListener<T> listener) { + final String tableSortOrder = parameters.getOrDefault("sortOrder", "asc"); + final String tableSortString = parameters.getOrDefault("sortString", "monitor_name.keyword"); + final int tableSize = parameters.containsKey("size") && StringUtils.isNumeric(parameters.get("size")) + ? Integer.parseInt(parameters.get("size")) + : 20; + final int startIndex = parameters.containsKey("startIndex") && StringUtils.isNumeric(parameters.get("startIndex")) + ? 
Integer.parseInt(parameters.get("startIndex")) + : 0; + final String searchString = parameters.getOrDefault("searchString", null); + + // not exposing "missing" from the table, using default of null + final Table table = new Table(tableSortOrder, tableSortString, null, tableSize, startIndex, searchString); + + final String severityLevel = parameters.getOrDefault("severityLevel", "ALL"); + final String alertState = parameters.getOrDefault("alertState", "ALL"); + final String monitorId = parameters.getOrDefault("monitorId", null); + final String alertIndex = parameters.getOrDefault("alertIndex", null); + @SuppressWarnings("unchecked") + final List monitorIds = parameters.containsKey("monitorIds") + ? gson.fromJson(parameters.get("monitorIds"), List.class) + : null; + @SuppressWarnings("unchecked") + final List workflowIds = parameters.containsKey("workflowIds") + ? gson.fromJson(parameters.get("workflowIds"), List.class) + : null; + @SuppressWarnings("unchecked") + final List alertIds = parameters.containsKey("alertIds") ? gson.fromJson(parameters.get("alertIds"), List.class) : null; + + GetAlertsRequest getAlertsRequest = new GetAlertsRequest( + table, + severityLevel, + alertState, + monitorId, + alertIndex, + monitorIds, + workflowIds, + alertIds, + null + ); + + // create response listener + // stringify the response, may change to a standard format in the future + ActionListener getAlertsListener = ActionListener.wrap(response -> { + StringBuilder sb = new StringBuilder(); + sb.append("Alerts=["); + for (Alert alert : response.getAlerts()) { + sb.append(alert.toString()); + } + sb.append("]"); + sb.append("TotalAlerts=").append(response.getTotalAlerts()); + listener.onResponse((T) sb.toString()); + }, e -> { + log.error("Failed to search alerts.", e); + listener.onFailure(e); + }); + + // execute the search + AlertingPluginInterface.INSTANCE.getAlerts((NodeClient) client, getAlertsRequest, getAlertsListener); + } + + @Override + public boolean validate(Map parameters) { + return true; + } + + @Override + public String getType() { + return TYPE; + } + + /** + * Factory for the {@link SearchAlertsTool} + */ + public static class Factory implements Tool.Factory { + private Client client; + + private static Factory INSTANCE; + + /** + * Create or return the singleton factory instance + */ + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (SearchAlertsTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + /** + * Initialize this factory + * @param client The OpenSearch client + */ + public void init(Client client) { + this.client = client; + } + + @Override + public SearchAlertsTool create(Map map) { + return new SearchAlertsTool(client); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public String getDefaultVersion() { + return null; + } + } + +} diff --git a/src/main/java/org/opensearch/agent/tools/SearchAnomalyDetectorsTool.java b/src/main/java/org/opensearch/agent/tools/SearchAnomalyDetectorsTool.java new file mode 100644 index 00000000..c8e4ab8a --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/SearchAnomalyDetectorsTool.java @@ -0,0 +1,332 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import java.util.ArrayList; +import java.util.Arrays; 
+import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.client.AnomalyDetectionNodeClient; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.transport.GetAnomalyDetectorResponse; +import org.opensearch.agent.tools.utils.ToolConstants; +import org.opensearch.agent.tools.utils.ToolConstants.DetectorStateString; +import org.opensearch.client.Client; +import org.opensearch.common.lucene.uid.Versions; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.query.WildcardQueryBuilder; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.SortOrder; +import org.opensearch.timeseries.transport.GetConfigRequest; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@ToolAnnotation(SearchAnomalyDetectorsTool.TYPE) +public class SearchAnomalyDetectorsTool implements Tool { + public static final String TYPE = "SearchAnomalyDetectorsTool"; + private static final String DEFAULT_DESCRIPTION = + "This is a tool that searches anomaly detectors. It takes 11 optional arguments named detectorName which is the explicit name of the detector (default is null), and detectorNamePattern which is a wildcard query to match detector name (default is null), and indices which defines the index or index pattern the detector is detecting over (default is null), and highCardinality which defines whether the anomaly detector is high cardinality (synonymous with multi-entity) or non-high-cardinality (synonymous with single-entity) (default is null, indicating both), and lastUpdateTime which defines the latest update time of the anomaly detector in epoch milliseconds (default is null), and sortOrder which defines the order of the results (options are asc or desc, and default is asc), and sortString which defines how to sort the results (default is name.keyword), and size which defines the size of the request to be returned (default is 20), and startIndex which defines the paginated index to start from (default is 0), and running which defines whether the anomaly detector is running (default is null, indicating both), and failed which defines whether the anomaly detector has failed (default is null, indicating both). 
The tool returns 2 values: a list of anomaly detectors (each containing the detector id, detector name, detector type indicating multi-entity or single-entity (where multi-entity also means high-cardinality), detector description, name of the configured index, last update time in epoch milliseconds), and the total number of anomaly detectors."; + + @Setter + @Getter + private String name = TYPE; + @Getter + @Setter + private String description = DEFAULT_DESCRIPTION; + + @Getter + private String version; + + private Client client; + + private AnomalyDetectionNodeClient adClient; + + @Setter + private Parser inputParser; + @Setter + private Parser outputParser; + + public SearchAnomalyDetectorsTool(Client client, NamedWriteableRegistry namedWriteableRegistry) { + this.client = client; + this.adClient = new AnomalyDetectionNodeClient(client, namedWriteableRegistry); + + // probably keep this overridden output parser. need to ensure the output matches what's expected + outputParser = new Parser<>() { + @Override + public Object parse(Object o) { + @SuppressWarnings("unchecked") + List mlModelOutputs = (List) o; + return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); + } + }; + } + + // Response is currently in a simple string format including the list of anomaly detectors (only name and ID attached), and + // number of total detectors. The output will likely need to be updated, standardized, and include more fields in the + // future to cover a sufficient amount of potential questions the agent will need to handle. + @Override + public void run(Map parameters, ActionListener listener) { + final String detectorName = parameters.getOrDefault("detectorName", null); + final String detectorNamePattern = parameters.getOrDefault("detectorNamePattern", null); + final String indices = parameters.getOrDefault("indices", null); + final Boolean highCardinality = parameters.containsKey("highCardinality") + ? Boolean.parseBoolean(parameters.get("highCardinality")) + : null; + final Long lastUpdateTime = parameters.containsKey("lastUpdateTime") && StringUtils.isNumeric(parameters.get("lastUpdateTime")) + ? Long.parseLong(parameters.get("lastUpdateTime")) + : null; + final String sortOrderStr = parameters.getOrDefault("sortOrder", "asc"); + final SortOrder sortOrder = sortOrderStr.equalsIgnoreCase("asc") ? SortOrder.ASC : SortOrder.DESC; + final String sortString = parameters.getOrDefault("sortString", "name.keyword"); + final int size = parameters.containsKey("size") ? Integer.parseInt(parameters.get("size")) : 20; + final int startIndex = parameters.containsKey("startIndex") ? Integer.parseInt(parameters.get("startIndex")) : 0; + final Boolean running = parameters.containsKey("running") ? Boolean.parseBoolean(parameters.get("running")) : null; + final Boolean failed = parameters.containsKey("failed") ? Boolean.parseBoolean(parameters.get("failed")) : null; + + List mustList = new ArrayList(); + if (detectorName != null) { + mustList.add(new TermQueryBuilder("name.keyword", detectorName)); + } + if (detectorNamePattern != null) { + mustList.add(new WildcardQueryBuilder("name.keyword", detectorNamePattern)); + } + if (indices != null) { + mustList.add(new TermQueryBuilder("indices.keyword", indices)); + } + if (highCardinality != null) { + mustList.add(new TermQueryBuilder("detector_type", highCardinality ? 
"MULTI_ENTITY" : "SINGLE_ENTITY")); + } + if (lastUpdateTime != null) { + mustList.add(new BoolQueryBuilder().filter(new RangeQueryBuilder("last_update_time").gte(lastUpdateTime))); + + } + + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + boolQueryBuilder.must().addAll(mustList); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .query(boolQueryBuilder) + .size(size) + .from(startIndex) + .sort(sortString, sortOrder); + + SearchRequest searchDetectorRequest = new SearchRequest().source(searchSourceBuilder).indices(ToolConstants.AD_DETECTORS_INDEX); + + ActionListener searchDetectorListener = ActionListener.wrap(response -> { + StringBuilder sb = new StringBuilder(); + List hits = Arrays.asList(response.getHits().getHits()); + Map hitsAsMap = new HashMap<>(); + // We persist the hits map using detector name as the key. Note this is required to be unique from the AD plugin. + // We cannot use detector ID, because the detector in the response from the profile transport action does not include this, + // making it difficult to map potential hits that should be removed later on based on the profile response's detector state. + for (SearchHit hit : hits) { + hitsAsMap.put((String) hit.getSourceAsMap().get("name"), hit); + } + + // If we need to filter by detector state, make subsequent profile API calls to each detector + if (running != null || failed != null) { + List> profileFutures = new ArrayList<>(); + for (SearchHit hit : hits) { + CompletableFuture profileFuture = new CompletableFuture() + .orTimeout(30000, TimeUnit.MILLISECONDS); + profileFutures.add(profileFuture); + ActionListener profileListener = ActionListener + .wrap(profileResponse -> { + profileFuture.complete(profileResponse); + }, e -> { + log.error("Failed to get anomaly detector profile.", e); + profileFuture.completeExceptionally(e); + listener.onFailure(e); + }); + + GetConfigRequest profileRequest = new GetConfigRequest( + hit.getId(), + Versions.MATCH_ANY, + false, + true, + "", + "", + false, + null + ); + adClient.getDetectorProfile(profileRequest, profileListener); + } + + List profileResponses = new ArrayList<>(); + try { + CompletableFuture> listFuture = CompletableFuture + .allOf(profileFutures.toArray(new CompletableFuture[0])) + .thenApply(v -> profileFutures.stream().map(CompletableFuture::join).collect(Collectors.toList())); + profileResponses = listFuture.join(); + } catch (Exception e) { + log.error("Failed to get all anomaly detector profiles.", e); + listener.onFailure(e); + } + + for (GetAnomalyDetectorResponse profileResponse : profileResponses) { + if (profileResponse != null && profileResponse.getDetector() != null) { + String responseDetectorName = profileResponse.getDetector().getName(); + + // We follow the existing logic as the frontend to determine overall detector state + // https://github.com/opensearch-project/anomaly-detection-dashboards-plugin/blob/main/server/routes/utils/adHelpers.ts#L437 + String detectorState = DetectorStateString.Disabled.name(); + ADTask realtimeTask = profileResponse.getRealtimeAdTask(); + + if (realtimeTask != null) { + String taskState = realtimeTask.getState(); + if (taskState.equalsIgnoreCase("CREATED") || taskState.equalsIgnoreCase("RUNNING")) { + detectorState = DetectorStateString.Running.name(); + } else if (taskState.equalsIgnoreCase("INIT_FAILURE") + || taskState.equalsIgnoreCase("UNEXPECTED_FAILURE") + || taskState.equalsIgnoreCase("FAILED")) { + detectorState = DetectorStateString.Failed.name(); + } + } + + boolean 
includeRunning = running != null && running == true; + boolean includeFailed = failed != null && failed == true; + boolean isValid = true; + + // Keep or drop this hit depending on how its state matches the requested running/failed filters + if (detectorState.equals(DetectorStateString.Running.name())) { + isValid = (running == null || running == true) && !(includeFailed && running == null); + } else if (detectorState.equals(DetectorStateString.Failed.name())) { + isValid = (failed == null || failed == true) && !(includeRunning && failed == null); + } else if (detectorState.equals(DetectorStateString.Disabled.name())) { + isValid = (running == null || running == false) && !(includeFailed && running == null); + } + + if (!isValid) { + hitsAsMap.remove(responseDetectorName); + } + } + } + } + + processHits(hitsAsMap, listener); + }, e -> { + // System index isn't initialized by default, so ignore such errors + if (e instanceof IndexNotFoundException) { + processHits(Collections.emptyMap(), listener); + } else { + log.error("Failed to search anomaly detectors.", e); + listener.onFailure(e); + } + + }); + + adClient.searchAnomalyDetectors(searchDetectorRequest, searchDetectorListener); + } + + @Override + public boolean validate(Map<String, String> parameters) { + return true; + } + + @Override + public String getType() { + return TYPE; + } + + private <T> void processHits(Map<String, SearchHit> hitsAsMap, ActionListener<T> listener) { + StringBuilder sb = new StringBuilder(); + sb.append("AnomalyDetectors=["); + for (SearchHit hit : hitsAsMap.values()) { + sb.append("{"); + sb.append("id=").append(hit.getId()).append(","); + sb.append("name=").append(hit.getSourceAsMap().get("name")).append(","); + sb.append("type=").append(hit.getSourceAsMap().get("detector_type")).append(","); + sb.append("description=").append(hit.getSourceAsMap().get("description")).append(","); + sb.append("index=").append(hit.getSourceAsMap().get("indices")).append(","); + sb.append("lastUpdateTime=").append(hit.getSourceAsMap().get("last_update_time")); + sb.append("}"); + } + sb.append("]"); + sb.append("TotalAnomalyDetectors=").append(hitsAsMap.size()); + listener.onResponse((T) sb.toString()); + } + + /** + * Factory for the {@link SearchAnomalyDetectorsTool} + */ + public static class Factory implements Tool.Factory<SearchAnomalyDetectorsTool> { + private Client client; + + private NamedWriteableRegistry namedWriteableRegistry; + + private AnomalyDetectionNodeClient adClient; + + private static Factory INSTANCE; + + /** + * Create or return the singleton factory instance + */ + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (SearchAnomalyDetectorsTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + /** + * Initialize this factory + * @param client The OpenSearch client + */ + public void init(Client client, NamedWriteableRegistry namedWriteableRegistry) { + this.client = client; + this.namedWriteableRegistry = namedWriteableRegistry; + this.adClient = new AnomalyDetectionNodeClient(client, namedWriteableRegistry); + } + + @Override + public SearchAnomalyDetectorsTool create(Map<String, Object> map) { + return new SearchAnomalyDetectorsTool(client, namedWriteableRegistry); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public String getDefaultVersion() { + return null; + } + } + +} diff --git a/src/main/java/org/opensearch/agent/tools/SearchAnomalyResultsTool.java b/src/main/java/org/opensearch/agent/tools/SearchAnomalyResultsTool.java new file 
mode 100644 index 00000000..a2973d6b --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/SearchAnomalyResultsTool.java @@ -0,0 +1,247 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.client.AnomalyDetectionNodeClient; +import org.opensearch.agent.tools.utils.ToolConstants; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.ExistsQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.SortOrder; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@ToolAnnotation(SearchAnomalyResultsTool.TYPE) +public class SearchAnomalyResultsTool implements Tool { + public static final String TYPE = "SearchAnomalyResultsTool"; + private static final String DEFAULT_DESCRIPTION = + "This is a tool that searches anomaly results. It takes 9 arguments named detectorId which defines the detector ID to filter for (default is null), and realtime which defines whether the anomaly results are from a realtime detector (set to false to only get results from historical analyses) (default is null), and anomalyGradeThreshold which defines the threshold for anomaly grade (a number between 0 and 1 that indicates how anomalous a data point is) (default is greater than 0), and dataStartTime which defines the start time of the anomaly data in epoch milliseconds (default is null), and dataEndTime which defines the end time of the anomaly data in epoch milliseconds (default is null), and sortOrder which defines the order of the results (options are asc or desc, and default is desc), and sortString which defines how to sort the results (default is data_start_time), and size which defines the number of anomalies to be returned (default is 20), and startIndex which defines the paginated index to start from (default is 0). 
The tool returns 2 values: a list of anomaly results (where each result contains the detector ID, the anomaly grade, and the confidence), and the total number of anomaly results."; + + @Setter + @Getter + private String name = TYPE; + @Getter + @Setter + private String description = DEFAULT_DESCRIPTION; + + @Getter + private String version; + + private Client client; + + private AnomalyDetectionNodeClient adClient; + + @Setter + private Parser inputParser; + @Setter + private Parser outputParser; + + public SearchAnomalyResultsTool(Client client, NamedWriteableRegistry namedWriteableRegistry) { + this.client = client; + this.adClient = new AnomalyDetectionNodeClient(client, namedWriteableRegistry); + + // Keep this overridden output parser so the output matches what the agent framework expects + outputParser = new Parser<>() { + @Override + public Object parse(Object o) { + @SuppressWarnings("unchecked") + List mlModelOutputs = (List) o; + return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); + } + }; + } + + // Response is currently in a simple string format including the list of anomaly results (only detector ID, grade, confidence), + // and total # of results. The output will likely need to be updated, standardized, and include more fields in the + // future to cover a sufficient amount of potential questions the agent will need to handle. + @Override + public void run(Map parameters, ActionListener listener) { + final String detectorId = parameters.getOrDefault("detectorId", null); + final Boolean realTime = parameters.containsKey("realTime") ? Boolean.parseBoolean(parameters.get("realTime")) : null; + final Double anomalyGradeThreshold = parameters.containsKey("anomalyGradeThreshold") + ? Double.parseDouble(parameters.get("anomalyGradeThreshold")) + : 0; + final Long dataStartTime = parameters.containsKey("dataStartTime") && StringUtils.isNumeric(parameters.get("dataStartTime")) + ? Long.parseLong(parameters.get("dataStartTime")) + : null; + final Long dataEndTime = parameters.containsKey("dataEndTime") && StringUtils.isNumeric(parameters.get("dataEndTime")) + ? Long.parseLong(parameters.get("dataEndTime")) + : null; + final String sortOrderStr = parameters.getOrDefault("sortOrder", "desc"); + final SortOrder sortOrder = sortOrderStr.equalsIgnoreCase("asc") ? SortOrder.ASC : SortOrder.DESC; + final String sortString = parameters.getOrDefault("sortString", "data_start_time"); + final int size = parameters.containsKey("size") ? Integer.parseInt(parameters.get("size")) : 20; + final int startIndex = parameters.containsKey("startIndex") ? Integer.parseInt(parameters.get("startIndex")) : 0; + + List mustList = new ArrayList(); + if (detectorId != null) { + mustList.add(new TermQueryBuilder("detector_id", detectorId)); + } + // We include or exclude the task ID if fetching historical or real-time results, respectively. 
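+ // Real-time detector results are written without a task_id, while results produced by a historical analysis carry + // the ID of the batch task that generated them, so an exists query on task_id separates the two kinds of results.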
+ // For more details, see https://opensearch.org/docs/latest/observing-your-data/ad/api/#search-detector-result + if (realTime != null) { + BoolQueryBuilder boolQuery = new BoolQueryBuilder(); + ExistsQueryBuilder existsQuery = new ExistsQueryBuilder("task_id"); + if (realTime) { + boolQuery.mustNot(existsQuery); + } else { + boolQuery.must(existsQuery); + } + mustList.add(boolQuery); + } + if (anomalyGradeThreshold != null) { + mustList.add(new RangeQueryBuilder("anomaly_grade").gt(anomalyGradeThreshold)); + } + // Filter results whose data_start_time falls within the requested window + if (dataStartTime != null || dataEndTime != null) { + RangeQueryBuilder rangeQuery = new RangeQueryBuilder("data_start_time"); + if (dataStartTime != null) { + rangeQuery.gte(dataStartTime); + } + if (dataEndTime != null) { + rangeQuery.lte(dataEndTime); + } + mustList.add(rangeQuery); + } + + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + boolQueryBuilder.must().addAll(mustList); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .query(boolQueryBuilder) + .size(size) + .from(startIndex) + .sort(sortString, sortOrder); + + // In the future we may support custom result indices. For now, use the default AD result system indices. + SearchRequest searchAnomalyResultsRequest = new SearchRequest() + .source(searchSourceBuilder) + .indices(ToolConstants.AD_RESULTS_INDEX_PATTERN); + + ActionListener searchAnomalyResultsListener = ActionListener.wrap(response -> { + processHits(response.getHits(), listener); + }, e -> { + // System index isn't initialized by default, so ignore such errors + if (e instanceof IndexNotFoundException) { + processHits(SearchHits.empty(), listener); + } else { + log.error("Failed to search anomaly results.", e); + listener.onFailure(e); + } + }); + + adClient.searchAnomalyResults(searchAnomalyResultsRequest, searchAnomalyResultsListener); + } + + @Override + public boolean validate(Map parameters) { + return true; + } + + @Override + public String getType() { + return TYPE; + } + + private void processHits(SearchHits searchHits, ActionListener listener) { + SearchHit[] hits = searchHits.getHits(); + + StringBuilder sb = new StringBuilder(); + sb.append("AnomalyResults=["); + for (SearchHit hit : hits) { + sb.append("{"); + sb.append("detectorId=").append(hit.getSourceAsMap().get("detector_id")).append(","); + sb.append("grade=").append(hit.getSourceAsMap().get("anomaly_grade")).append(","); + sb.append("confidence=").append(hit.getSourceAsMap().get("confidence")); + sb.append("}"); + } + sb.append("]"); + sb.append("TotalAnomalyResults=").append(searchHits.getTotalHits().value); + listener.onResponse((T) sb.toString()); + } + + /** + * Factory for the {@link SearchAnomalyResultsTool} + */ + public static class Factory implements Tool.Factory { + private Client client; + + private NamedWriteableRegistry namedWriteableRegistry; + + private AnomalyDetectionNodeClient adClient; + + private static Factory INSTANCE; + + /** + * Create or return the singleton factory instance + */ + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (SearchAnomalyResultsTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + /** + * Initialize this factory + * @param client The OpenSearch client + */ + public void init(Client client, NamedWriteableRegistry namedWriteableRegistry) { + this.client = client; + this.namedWriteableRegistry = namedWriteableRegistry; + this.adClient = new AnomalyDetectionNodeClient(client, 
namedWriteableRegistry); + } + + @Override + public SearchAnomalyResultsTool create(Map map) { + return new SearchAnomalyResultsTool(client, namedWriteableRegistry); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public String getDefaultVersion() { + return null; + } + } + +} diff --git a/src/main/java/org/opensearch/agent/tools/SearchMonitorsTool.java b/src/main/java/org/opensearch/agent/tools/SearchMonitorsTool.java new file mode 100644 index 00000000..433994cf --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/SearchMonitorsTool.java @@ -0,0 +1,252 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.apache.commons.lang3.StringUtils; +import org.apache.lucene.search.join.ScoreMode; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.client.node.NodeClient; +import org.opensearch.commons.alerting.AlertingPluginInterface; +import org.opensearch.commons.alerting.action.SearchMonitorRequest; +import org.opensearch.commons.alerting.model.ScheduledJob; +import org.opensearch.core.action.ActionListener; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.ExistsQueryBuilder; +import org.opensearch.index.query.NestedQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.query.WildcardQueryBuilder; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.SortOrder; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@ToolAnnotation(SearchMonitorsTool.TYPE) +public class SearchMonitorsTool implements Tool { + public static final String TYPE = "SearchMonitorsTool"; + private static final String DEFAULT_DESCRIPTION = + "This is a tool that searches alerting monitors. It takes 10 optional arguments named monitorId which defines the monitor ID to filter for (default is null), and monitorName which defines the explicit name of the monitor (default is null), and monitorNamePattern which is a wildcard query to match monitor name (default is null), and enabled which defines whether the monitor is enabled (default is null, indicating both enabled and disabled), and hasTriggers which defines whether the monitor has triggers enabled (default is null, indicating both), and indices which defines the index being monitored (default is null), and sortOrder which defines the order of the results (options are asc or desc, and default is asc), and sortString which defines how to sort the results (default is monitor.name.keyword), and size which defines the size of the request to be returned (default is 20), and startIndex which defines the paginated index to start from (default is 0). 
The tool returns 2 values: a list of alerting monitors (each containing monitor ID, monitor name, monitor type (indicating query-level, document-level, or bucket-level monitor types), enabled, enabled time in epoch milliseconds, last update time in epoch milliseconds), and the total number of alerting monitors."; + @Setter + @Getter + private String name = TYPE; + @Getter + @Setter + private String description = DEFAULT_DESCRIPTION; + @Getter + private String type; + @Getter + private String version; + + private Client client; + @Setter + private Parser inputParser; + @Setter + private Parser outputParser; + + public SearchMonitorsTool(Client client) { + this.client = client; + + // Keep this overridden output parser so the output matches what the agent framework expects + outputParser = new Parser<>() { + @Override + public Object parse(Object o) { + @SuppressWarnings("unchecked") + List mlModelOutputs = (List) o; + return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); + } + }; + } + + // Response is currently in a simple string format including the list of monitors (only name and ID attached), and + // number of total monitors. The output will likely need to be updated, standardized, and include more fields in the + // future to cover a sufficient amount of potential questions the agent will need to handle. + @Override + public void run(Map parameters, ActionListener listener) { + final String monitorId = parameters.getOrDefault("monitorId", null); + final String monitorName = parameters.getOrDefault("monitorName", null); + final String monitorNamePattern = parameters.getOrDefault("monitorNamePattern", null); + final Boolean enabled = parameters.containsKey("enabled") ? Boolean.parseBoolean(parameters.get("enabled")) : null; + final Boolean hasTriggers = parameters.containsKey("hasTriggers") ? Boolean.parseBoolean(parameters.get("hasTriggers")) : null; + final String indices = parameters.getOrDefault("indices", null); + final String sortOrderStr = parameters.getOrDefault("sortOrder", "asc"); + final SortOrder sortOrder = "asc".equalsIgnoreCase(sortOrderStr) ? SortOrder.ASC : SortOrder.DESC; + final String sortString = parameters.getOrDefault("sortString", "monitor.name.keyword"); + final int size = parameters.containsKey("size") && StringUtils.isNumeric(parameters.get("size")) + ? Integer.parseInt(parameters.get("size")) + : 20; + final int startIndex = parameters.containsKey("startIndex") && StringUtils.isNumeric(parameters.get("startIndex")) + ? 
Integer.parseInt(parameters.get("startIndex")) + : 0; + + List mustList = new ArrayList(); + if (monitorId != null) { + mustList.add(new TermQueryBuilder("_id", monitorId)); + } + if (monitorName != null) { + mustList.add(new TermQueryBuilder("monitor.name.keyword", monitorName)); + } + if (monitorNamePattern != null) { + mustList.add(new WildcardQueryBuilder("monitor.name.keyword", monitorNamePattern)); + } + if (enabled != null) { + mustList.add(new TermQueryBuilder("monitor.enabled", enabled)); + } + if (hasTriggers != null) { + NestedQueryBuilder nestedTriggerQuery = new NestedQueryBuilder( + "monitor.triggers", + new ExistsQueryBuilder("monitor.triggers"), + ScoreMode.None + ); + + BoolQueryBuilder triggerQuery = new BoolQueryBuilder(); + if (hasTriggers) { + triggerQuery.must(nestedTriggerQuery); + } else { + triggerQuery.mustNot(nestedTriggerQuery); + } + mustList.add(triggerQuery); + } + if (indices != null) { + mustList + .add( + new NestedQueryBuilder( + "monitor.inputs", + new WildcardQueryBuilder("monitor.inputs.search.indices", indices), + ScoreMode.None + ) + ); + } + + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + boolQueryBuilder.must().addAll(mustList); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .query(boolQueryBuilder) + .size(size) + .from(startIndex) + .sort(sortString, sortOrder); + + SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(ScheduledJob.SCHEDULED_JOBS_INDEX); + SearchMonitorRequest searchMonitorRequest = new SearchMonitorRequest(searchRequest); + + ActionListener searchMonitorListener = ActionListener.wrap(response -> { + List hits = Arrays.asList(response.getHits().getHits()); + Map hitsAsMap = hits.stream().collect(Collectors.toMap(SearchHit::getId, hit -> hit)); + processHits(hitsAsMap, listener); + + }, e -> { + // System index isn't initialized by default, so ignore such errors. 
Alerting plugin does not return the + // standard IndexNotFoundException so we parse the message instead + if (e.getMessage().contains("Configured indices are not found")) { + processHits(Collections.emptyMap(), listener); + } else { + log.error("Failed to search monitors.", e); + listener.onFailure(e); + } + }); + AlertingPluginInterface.INSTANCE.searchMonitors((NodeClient) client, searchMonitorRequest, searchMonitorListener); + + } + + @Override + public boolean validate(Map parameters) { + return true; + } + + @Override + public String getType() { + return TYPE; + } + + private void processHits(Map hitsAsMap, ActionListener listener) { + StringBuilder sb = new StringBuilder(); + sb.append("Monitors=["); + for (SearchHit hit : hitsAsMap.values()) { + Map monitorAsMap = (Map) hit.getSourceAsMap().get("monitor"); + sb.append("{"); + sb.append("id=").append(hit.getId()).append(","); + sb.append("name=").append(monitorAsMap.get("name")).append(","); + sb.append("type=").append(monitorAsMap.get("monitor_type")).append(","); + sb.append("enabled=").append(monitorAsMap.get("enabled")).append(","); + sb.append("enabledTime=").append(monitorAsMap.get("enabled_time")).append(","); + sb.append("lastUpdateTime=").append(monitorAsMap.get("last_update_time")); + sb.append("}"); + } + sb.append("]"); + sb.append("TotalMonitors=").append(hitsAsMap.size()); + listener.onResponse((T) sb.toString()); + } + + /** + * Factory for the {@link SearchMonitorsTool} + */ + public static class Factory implements Tool.Factory { + private Client client; + + private static Factory INSTANCE; + + /** + * Create or return the singleton factory instance + */ + public static Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (SearchMonitorsTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new Factory(); + return INSTANCE; + } + } + + /** + * Initialize this factory + * @param client The OpenSearch client + */ + public void init(Client client) { + this.client = client; + } + + @Override + public SearchMonitorsTool create(Map map) { + return new SearchMonitorsTool(client); + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public String getDefaultVersion() { + return null; + } + } + +} diff --git a/src/main/java/org/opensearch/agent/tools/VectorDBTool.java b/src/main/java/org/opensearch/agent/tools/VectorDBTool.java new file mode 100644 index 00000000..d397060e --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/VectorDBTool.java @@ -0,0 +1,167 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; +import java.util.Map; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.client.Client; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.spi.tools.ToolAnnotation; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +/** + * This tool supports neural search with embedding models and knn index. 
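+ * <p> + * As a rough illustration (the field and model names below are made up), a non-nested configuration generates a + * neural query of approximately this shape: + * {"query":{"neural":{"passage_embedding":{"query_text":"...","model_id":"my-model-id","k":10}}}} + * When nested_path is set, the same neural clause is wrapped in a nested query on that path with score_mode max.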
+ */ +@Log4j2 +@Getter +@Setter +@ToolAnnotation(VectorDBTool.TYPE) +public class VectorDBTool extends AbstractRetrieverTool { + public static final String TYPE = "VectorDBTool"; + + public static String DEFAULT_DESCRIPTION = + "Use this tool to perform knn-based dense retrieval. It takes 1 argument named input which is a string query for dense retrieval. The tool returns the dense retrieval results for the query."; + public static final String MODEL_ID_FIELD = "model_id"; + public static final String EMBEDDING_FIELD = "embedding_field"; + public static final String K_FIELD = "k"; + public static final Integer DEFAULT_K = 10; + public static final String NESTED_PATH_FIELD = "nested_path"; + + private String name = TYPE; + private String modelId; + private String embeddingField; + private Integer k; + private String nestedPath; + + @Builder + public VectorDBTool( + Client client, + NamedXContentRegistry xContentRegistry, + String index, + String embeddingField, + String[] sourceFields, + Integer docSize, + String modelId, + Integer k, + String nestedPath + ) { + super(client, xContentRegistry, index, sourceFields, docSize); + this.modelId = modelId; + this.embeddingField = embeddingField; + this.k = k; + this.nestedPath = nestedPath; + } + + @Override + protected String getQueryBody(String queryText) { + if (StringUtils.isBlank(embeddingField) || StringUtils.isBlank(modelId)) { + throw new IllegalArgumentException( + "Parameters [" + EMBEDDING_FIELD + "] and [" + MODEL_ID_FIELD + "] cannot be null or empty." + ); + } + + Map queryBody; + if (StringUtils.isBlank(nestedPath)) { + queryBody = Map + .of("query", Map.of("neural", Map.of(embeddingField, Map.of("query_text", queryText, "model_id", modelId, "k", k)))); + + } else { + queryBody = Map + .of( + "query", + Map + .of( + "nested", + Map + .of( + "path", + nestedPath, + "score_mode", + "max", + "query", + Map.of("neural", Map.of(embeddingField, Map.of("query_text", queryText, "model_id", modelId, "k", k))) + ) + ) + ); + } + + try { + return AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(queryBody)); + } catch (PrivilegedActionException e) { + throw new RuntimeException(e); + } + } + + @Override + public String getType() { + return TYPE; + } + + public static class Factory extends AbstractRetrieverTool.Factory { + private static VectorDBTool.Factory INSTANCE; + + public static VectorDBTool.Factory getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (VectorDBTool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new VectorDBTool.Factory(); + return INSTANCE; + } + } + + @Override + public VectorDBTool create(Map params) { + String index = (String) params.get(INDEX_FIELD); + String embeddingField = (String) params.get(EMBEDDING_FIELD); + String[] sourceFields = gson.fromJson((String) params.get(SOURCE_FIELD), String[].class); + String modelId = (String) params.get(MODEL_ID_FIELD); + Integer docSize = params.containsKey(DOC_SIZE_FIELD) ? Integer.parseInt((String) params.get(DOC_SIZE_FIELD)) : DEFAULT_DOC_SIZE; + Integer k = params.containsKey(K_FIELD) ? 
Integer.parseInt((String) params.get(K_FIELD)) : DEFAULT_K; + String nestedPath = (String) params.get(NESTED_PATH_FIELD); + return VectorDBTool + .builder() + .client(client) + .xContentRegistry(xContentRegistry) + .index(index) + .embeddingField(embeddingField) + .sourceFields(sourceFields) + .modelId(modelId) + .docSize(docSize) + .k(k) + .nestedPath(nestedPath) + .build(); + } + + @Override + public String getDefaultType() { + return TYPE; + } + + @Override + public String getDefaultVersion() { + return null; + } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } + } +} diff --git a/src/main/java/org/opensearch/agent/tools/utils/ToolConstants.java b/src/main/java/org/opensearch/agent/tools/utils/ToolConstants.java new file mode 100644 index 00000000..2a90ec7e --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/utils/ToolConstants.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools.utils; + +public class ToolConstants { + // Detector state is not cleanly defined on the backend plugin. So, we persist a standard + // set of states here for users to interface with when fetching and filtering detectors. + // This follows what frontend AD users are familiar with, as we use the same parsing logic + // in SearchAnomalyDetectorsTool. + public enum DetectorStateString { + Running, + Disabled, + Failed, + Initializing + } + + // System indices constants are not cleanly exposed from the AD & Alerting plugins, so we persist our + // own constants here. + public static final String AD_RESULTS_INDEX_PATTERN = ".opendistro-anomaly-results*"; + public static final String AD_RESULTS_INDEX = ".opendistro-anomaly-results"; + public static final String AD_DETECTORS_INDEX = ".opendistro-anomaly-detectors"; + + public static final String ALERTING_CONFIG_INDEX = ".opendistro-alerting-config"; + public static final String ALERTING_ALERTS_INDEX = ".opendistro-alerting-alerts"; + +} diff --git a/src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java b/src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java new file mode 100644 index 00000000..d7f6c3f5 --- /dev/null +++ b/src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools.utils; + +import java.util.Map; + +public class ToolHelper { + /** + * Flatten all the fields in an index mapping and insert each field-name-to-field-type pair into the given map + * @param mappingSource the mappings of an index + * @param fieldsToType the result map containing the field-name-to-field-type mapping + * @param prefix the parent field path + * @param includeFields whether to include the `fields` entries of a text type field; for some use cases like PPLTool, `fields` of a text type field + * cannot be included, but for CreateAnomalyDetectorTool, `fields` must be included. 
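+ * <p> + * For example, given a hypothetical mapping source {"title":{"type":"text","fields":{"keyword":{"type":"keyword"}}},"user":{"properties":{"id":{"type":"long"}}}}, + * the method records "title" -> "text" and "user.id" -> "long", plus "title.keyword" -> "keyword" when includeFields is true.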
+ */ + public static void extractFieldNamesTypes( + Map mappingSource, + Map fieldsToType, + String prefix, + boolean includeFields + ) { + if (prefix.length() > 0) { + prefix += "."; + } + + for (Map.Entry entry : mappingSource.entrySet()) { + String n = entry.getKey(); + Object v = entry.getValue(); + + if (v instanceof Map) { + Map vMap = (Map) v; + if (vMap.containsKey("type")) { + String fieldType = (String) vMap.getOrDefault("type", ""); + // no need to extract alias into the result, and for object field, extract the subfields only + if (!fieldType.equals("alias") && !fieldType.equals("object")) { + fieldsToType.put(prefix + n, (String) vMap.get("type")); + } + } + if (vMap.containsKey("properties")) { + extractFieldNamesTypes((Map) vMap.get("properties"), fieldsToType, prefix + n, includeFields); + } + if (includeFields && vMap.containsKey("fields")) { + extractFieldNamesTypes((Map) vMap.get("fields"), fieldsToType, prefix + n, true); + } + } + } + } +} diff --git a/src/main/plugin-metadata/plugin-security.policy b/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 00000000..4c512a49 --- /dev/null +++ b/src/main/plugin-metadata/plugin-security.policy @@ -0,0 +1,8 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +grant { + permission java.lang.RuntimePermission "accessDeclaredMembers"; +}; diff --git a/src/main/resources/org/opensearch/agent/tools/CreateAnomalyDetectorDefaultPrompt.json b/src/main/resources/org/opensearch/agent/tools/CreateAnomalyDetectorDefaultPrompt.json new file mode 100644 index 00000000..9b69bce7 --- /dev/null +++ b/src/main/resources/org/opensearch/agent/tools/CreateAnomalyDetectorDefaultPrompt.json @@ -0,0 +1,4 @@ +{ + "CLAUDE": "Human:\" turn\": Here is an example of the create anomaly detector API: POST _plugins/_anomaly_detection/detectors, {\"time_field\":\"timestamp\",\"indices\":[\"server_log*\"],\"feature_attributes\":[{\"feature_name\":\"test\",\"feature_enabled\":true,\"aggregation_query\":{\"test\":{\"sum\":{\"field\":\"value\"}}}}],\"category_field\":[\"ip\"]}, and here are the mapping info containing all the fields in the index ${indexInfo.indexName}: ${indexInfo.indexMapping}, and the optional aggregation methods are count, avg, min, max and sum. Please give me some suggestion about creating an anomaly detector for the index ${indexInfo.indexName}, you need to give the key information: the top 3 suitable aggregation fields which are numeric types and the suitable aggregation method for each field, if there are no numeric type fields, both the aggregation field and method are empty string, and also give the category field if there exists a keyword type field like ip, address, host, city, country or region, if not exist, the category field is empty. Show me a format of keyed and pipe-delimited list wrapped in a curly bracket just like {category_field=the category field if exists|aggregation_field=comma-delimited list of all the aggregation field names|aggregation_method=comma-delimited list of all the aggregation methods}. 
\n\nAssistant:\" turn\"", + "OPENAI": "Here is an example of the create anomaly detector API: POST _plugins/_anomaly_detection/detectors, {\"time_field\":\"timestamp\",\"indices\":[\"server_log*\"],\"feature_attributes\":[{\"feature_name\":\"test\",\"feature_enabled\":true,\"aggregation_query\":{\"test\":{\"sum\":{\"field\":\"value\"}}}}],\"category_field\":[\"ip\"]}, and here are the mapping info containing all the fields in the index ${indexInfo.indexName}: ${indexInfo.indexMapping}, and the optional aggregation methods are count, avg, min, max and sum. Please give me some suggestion about creating an anomaly detector for the index ${indexInfo.indexName}, you need to give the key information: the top 3 suitable aggregation fields which are numeric types and the suitable aggregation method for each field, if there are no numeric type fields, both the aggregation field and method are empty string, and also give the category field if there exists a keyword type field like ip, address, host, city, country or region, if not exist, the category field is empty. Show me a format of keyed and pipe-delimited list wrapped in a curly bracket just like {category_field=the category field if exists|aggregation_field=comma-delimited list of all the aggregation field names|aggregation_method=comma-delimited list of all the aggregation methods}. " +} diff --git a/src/main/resources/org/opensearch/agent/tools/PPLDefaultPrompt.json b/src/main/resources/org/opensearch/agent/tools/PPLDefaultPrompt.json new file mode 100644 index 00000000..98909cb3 --- /dev/null +++ b/src/main/resources/org/opensearch/agent/tools/PPLDefaultPrompt.json @@ -0,0 +1,5 @@ +{ + "CLAUDE": "\n\nHuman:You will be given a question about some metrics from a user.\nUse context provided to write a PPL query that can be used to retrieve the information.\n\nHere is a sample PPL query:\nsource=`` | where `` = '``'\n\nHere are some sample questions and the PPL query to retrieve the information. 
The format for fields is\n```\n- field_name: field_type (sample field value)\n```\n\nFor example, below is a field called `timestamp`, it has a field type of `date`, and a sample value of it could look like `1686000665919`.\n```\n- timestamp: date (1686000665919)\n```\n----------------\n\nThe following text contains fields and questions/answers for the 'accounts' index\n\nFields:\n- account_number: long (101)\n- address: text ('880 Holmes Lane')\n- age: long (32)\n- balance: long (39225)\n- city: text ('Brogan')\n- email: text ('amberduke@pyrami.com')\n- employer: text ('Pyrami')\n- firstname: text ('Amber')\n- gender: text ('M')\n- lastname: text ('Duke')\n- state: text ('IL')\n- registered_at: date (1686000665919)\n\nQuestion: Give me some documents in index 'accounts'\nPPL: source=`accounts` | head\n\nQuestion: Give me 5 oldest people in index 'accounts'\nPPL: source=`accounts` | sort -age | head 5\n\nQuestion: Give me first names of 5 youngest people in index 'accounts'\nPPL: source=`accounts` | sort +age | head 5 | fields `firstname`\n\nQuestion: Give me some addresses in index 'accounts'\nPPL: source=`accounts` | fields `address`\n\nQuestion: Find the documents in index 'accounts' where firstname is 'Hattie'\nPPL: source=`accounts` | where `firstname` = 'Hattie'\n\nQuestion: Find the emails where firstname is 'Hattie' or lastname is 'Frank' in index 'accounts'\nPPL: source=`accounts` | where `firstname` = 'Hattie' OR `lastname` = 'frank' | fields `email`\n\nQuestion: Find the documents in index 'accounts' where firstname is not 'Hattie' and lastname is not 'Frank'\nPPL: source=`accounts` | where `firstname` != 'Hattie' AND `lastname` != 'frank'\n\nQuestion: Find the emails that contain '.com' in index 'accounts'\nPPL: source=`accounts` | where QUERY_STRING(['email'], '.com') | fields `email`\n\nQuestion: Find the documents in index 'accounts' where there is an email\nPPL: source=`accounts` | where ISNOTNULL(`email`)\n\nQuestion: Count the number of documents in index 'accounts'\nPPL: source=`accounts` | stats COUNT() AS `count`\n\nQuestion: Count the number of people with firstname 'Amber' in index 'accounts'\nPPL: source=`accounts` | where `firstname` ='Amber' | stats COUNT() AS `count`\n\nQuestion: How many people are older than 33? index is 'accounts'\nPPL: source=`accounts` | where `age` > 33 | stats COUNT() AS `count`\n\nQuestion: How many distinct ages? index is 'accounts'\nPPL: source=`accounts` | stats DISTINCT_COUNT(age) AS `distinct_count`\n\nQuestion: How many males and females in index 'accounts'?\nPPL: source=`accounts` | stats COUNT() AS `count` BY `gender`\n\nQuestion: What is the average, minimum, maximum age in 'accounts' index?\nPPL: source=`accounts` | stats AVG(`age`) AS `avg_age`, MIN(`age`) AS `min_age`, MAX(`age`) AS `max_age`\n\nQuestion: Show all states sorted by average balance. 
index is 'accounts'\nPPL: source=`accounts` | stats AVG(`balance`) AS `avg_balance` BY `state` | sort +avg_balance\n\n----------------\n\nThe following text contains fields and questions/answers for the 'ecommerce' index\n\nFields:\n- category: text ('Men's Clothing')\n- currency: keyword ('EUR')\n- customer_birth_date: date (null)\n- customer_first_name: text ('Eddie')\n- customer_full_name: text ('Eddie Underwood')\n- customer_gender: keyword ('MALE')\n- customer_id: keyword ('38')\n- customer_last_name: text ('Underwood')\n- customer_phone: keyword ('')\n- day_of_week: keyword ('Monday')\n- day_of_week_i: integer (0)\n- email: keyword ('eddie@underwood-family.zzz')\n- event.dataset: keyword ('sample_ecommerce')\n- geoip.city_name: keyword ('Cairo')\n- geoip.continent_name: keyword ('Africa')\n- geoip.country_iso_code: keyword ('EG')\n- geoip.location: geo_point ([object Object])\n- geoip.region_name: keyword ('Cairo Governorate')\n- manufacturer: text ('Elitelligence,Oceanavigations')\n- order_date: date (2023-06-05T09:28:48+00:00)\n- order_id: keyword ('584677')\n- products._id: text (null)\n- products.base_price: half_float (null)\n- products.base_unit_price: half_float (null)\n- products.category: text (null)\n- products.created_on: date (null)\n- products.discount_amount: half_float (null)\n- products.discount_percentage: half_float (null)\n- products.manufacturer: text (null)\n- products.min_price: half_float (null)\n- products.price: half_float (null)\n- products.product_id: long (null)\n- products.product_name: text (null)\n- products.quantity: integer (null)\n- products.sku: keyword (null)\n- products.tax_amount: half_float (null)\n- products.taxful_price: half_float (null)\n- products.taxless_price: half_float (null)\n- products.unit_discount_amount: half_float (null)\n- sku: keyword ('ZO0549605496,ZO0299602996')\n- taxful_total_price: half_float (36.98)\n- taxless_total_price: half_float (36.98)\n- total_quantity: integer (2)\n- total_unique_products: integer (2)\n- type: keyword ('order')\n- user: keyword ('eddie')\n\nQuestion: What is the average price of products in clothing category ordered in the last 7 days? index is 'ecommerce'\nPPL: source=`ecommerce` | where QUERY_STRING(['category'], 'clothing') AND `order_date` > DATE_SUB(NOW(), INTERVAL 7 DAY) | stats AVG(`taxful_total_price`) AS `avg_price`\n\nQuestion: What is the average price of products in each city ordered today by every 2 hours? index is 'ecommerce'\nPPL: source=`ecommerce` | where `order_date` > DATE_SUB(NOW(), INTERVAL 24 HOUR) | stats AVG(`taxful_total_price`) AS `avg_price` by SPAN(`order_date`, 2h) AS `span`, `geoip.city_name`\n\nQuestion: What is the total revenue of shoes each day in this week? 
index is 'ecommerce'\nPPL: source=`ecommerce` | where QUERY_STRING(['category'], 'shoes') AND `order_date` > DATE_SUB(NOW(), INTERVAL 1 WEEK) | stats SUM(`taxful_total_price`) AS `revenue` by SPAN(`order_date`, 1d) AS `span`\n\n----------------\n\nThe following text contains fields and questions/answers for the 'events' index\nFields:\n- timestamp: long (1686000665919)\n- attributes.data_stream.dataset: text ('nginx.access')\n- attributes.data_stream.namespace: text ('production')\n- attributes.data_stream.type: text ('logs')\n- body: text ('172.24.0.1 - - [02/Jun/2023:23:09:27 +0000] 'GET / HTTP/1.1' 200 4955 '-' 'Mozilla/5.0 zgrab/0.x'')\n- communication.source.address: text ('127.0.0.1')\n- communication.source.ip: text ('172.24.0.1')\n- container_id: text (null)\n- container_name: text (null)\n- event.category: text ('web')\n- event.domain: text ('nginx.access')\n- event.kind: text ('event')\n- event.name: text ('access')\n- event.result: text ('success')\n- event.type: text ('access')\n- http.flavor: text ('1.1')\n- http.request.method: text ('GET')\n- http.response.bytes: long (4955)\n- http.response.status_code: keyword ('200')\n- http.url: text ('/')\n- log: text (null)\n- observerTime: date (1686000665919)\n- source: text (null)\n- span_id: text ('abcdef1010')\n- trace_id: text ('102981ABCD2901')\n\nQuestion: What are recent logs with errors and contains word 'test'? index is 'events'\nPPL: source=`events` | where QUERY_STRING(['http.response.status_code'], '4* OR 5*') AND QUERY_STRING(['body'], 'test') AND `observerTime` > DATE_SUB(NOW(), INTERVAL 5 MINUTE)\n\nQuestion: What is the total number of log with a status code other than 200 in 2023 Feburary? index is 'events'\nPPL: source=`events` | where QUERY_STRING(['http.response.status_code'], '!200') AND `observerTime` >= '2023-03-01 00:00:00' AND `observerTime` < '2023-04-01 00:00:00' | stats COUNT() AS `count`\n\nQuestion: Count the number of business days that have web category logs last week? index is 'events'\nPPL: source=`events` | where `category` = 'web' AND `observerTime` > DATE_SUB(NOW(), INTERVAL 1 WEEK) AND DAY_OF_WEEK(`observerTime`) >= 2 AND DAY_OF_WEEK(`observerTime`) <= 6 | stats DISTINCT_COUNT(DATE_FORMAT(`observerTime`, 'yyyy-MM-dd')) AS `distinct_count`\n\nQuestion: What are the top traces with largest bytes? index is 'events'\nPPL: source=`events` | stats SUM(`http.response.bytes`) AS `sum_bytes` by `trace_id` | sort -sum_bytes | head\n\nQuestion: Give me log patterns? index is 'events'\nPPL: source=`events` | patterns `body` | stats take(`body`, 1) AS `sample_pattern` by `patterns_field` | fields `sample_pattern`\n\nQuestion: Give me log patterns for logs with errors? index is 'events'\nPPL: source=`events` | where QUERY_STRING(['http.response.status_code'], '4* OR 5*') | patterns `body` | stats take(`body`, 1) AS `sample_pattern` by `patterns_field` | fields `sample_pattern`\n\n----------------\n\nUse the following steps to generate the PPL query:\n\nStep 1. Find all field entities in the question.\n\nStep 2. Pick the fields that are relevant to the question from the provided fields list using entities. Rules:\n#01 Consider the field name, the field type, and the sample value when picking relevant fields. For example, if you need to filter flights departed from 'JFK', look for a `text` or `keyword` field with a field name such as 'departedAirport', and the sample value should be a 3 letter IATA airport code. 
Similarly, if you need a date field, look for a relevant field name with type `date` and not `long`.\n#02 You must pick a field with `date` type when filtering on date/time.\n#03 You must pick a field with `date` type when aggregating by time interval.\n#04 You must not use the sample value in PPL query, unless it is relevant to the question.\n#05 You must only pick fields that are relevant, and must pick the whole field name from the fields list.\n#06 You must not use fields that are not in the fields list.\n#07 You must not use the sample values unless relevant to the question.\n#08 You must pick the field that contains a log line when asked about log patterns. Usually it is one of `log`, `body`, `message`.\n\nStep 3. Use the choosen fields to write the PPL query. Rules:\n#01 Always use comparisons to filter date/time, eg. 'where `timestamp` > DATE_SUB(NOW(), INTERVAL 1 DAY)'; or by absolute time: 'where `timestamp` > 'yyyy-MM-dd HH:mm:ss'', eg. 'where `timestamp` < '2023-01-01 00:00:00''. Do not use `DATE_FORMAT()`.\n#02 Only use PPL syntax and keywords appeared in the question or in the examples.\n#03 If user asks for current or recent status, filter the time field for last 5 minutes.\n#04 The field used in 'SPAN(``, )' must have type `date`, not `long`.\n#05 When aggregating by `SPAN` and another field, put `SPAN` after `by` and before the other field, eg. 'stats COUNT() AS `count` by SPAN(`timestamp`, 1d) AS `span`, `category`'.\n#06 You must put values in quotes when filtering fields with `text` or `keyword` field type.\n#07 To find documents that contain certain phrases in string fields, use `QUERY_STRING` which supports multiple fields and wildcard, eg. 'where QUERY_STRING(['field1', 'field2'], 'prefix*')'.\n#08 To find 4xx and 5xx errors using status code, if the status code field type is numberic (eg. `integer`), then use 'where `status_code` >= 400'; if the field is a string (eg. `text` or `keyword`), then use 'where QUERY_STRING(['status_code'], '4* OR 5*')'.\n\n----------------\nPut your PPL query in tags.\n----------------\nQuestion : ${indexInfo.question}? index is `${indexInfo.indexName}`\nFields:\n${indexInfo.mappingInfo}\n\nAssistant:", + "FINETUNE": "Below is an instruction that describes a task, paired with the index and corresponding fields that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nI have an opensearch index with fields in the following. Now I have a question: ${indexInfo.question}. Can you help me generate a PPL for that?\n\n### Index:\n${indexInfo.indexName}\n\n### Fields:\n${indexInfo.mappingInfo}\n\n### Response:\n", + "OPENAI": "You will be given a question about some metrics from a user.\nUse context provided to write a PPL query that can be used to retrieve the information.\n\nHere is a sample PPL query:\nsource=`` | where `` = '``'\n\nHere are some sample questions and the PPL query to retrieve the information. 
The format for fields is\n```\n- field_name: field_type (sample field value)\n```\n\nFor example, below is a field called `timestamp`, it has a field type of `date`, and a sample value of it could look like `1686000665919`.\n```\n- timestamp: date (1686000665919)\n```\n----------------\n\nThe following text contains fields and questions/answers for the 'accounts' index\n\nFields:\n- account_number: long (101)\n- address: text ('880 Holmes Lane')\n- age: long (32)\n- balance: long (39225)\n- city: text ('Brogan')\n- email: text ('amberduke@pyrami.com')\n- employer: text ('Pyrami')\n- firstname: text ('Amber')\n- gender: text ('M')\n- lastname: text ('Duke')\n- state: text ('IL')\n- registered_at: date (1686000665919)\n\nQuestion: Give me some documents in index 'accounts'\nPPL: source=`accounts` | head\n\nQuestion: Give me 5 oldest people in index 'accounts'\nPPL: source=`accounts` | sort -age | head 5\n\nQuestion: Give me first names of 5 youngest people in index 'accounts'\nPPL: source=`accounts` | sort +age | head 5 | fields `firstname`\n\nQuestion: Give me some addresses in index 'accounts'\nPPL: source=`accounts` | fields `address`\n\nQuestion: Find the documents in index 'accounts' where firstname is 'Hattie'\nPPL: source=`accounts` | where `firstname` = 'Hattie'\n\nQuestion: Find the emails where firstname is 'Hattie' or lastname is 'Frank' in index 'accounts'\nPPL: source=`accounts` | where `firstname` = 'Hattie' OR `lastname` = 'frank' | fields `email`\n\nQuestion: Find the documents in index 'accounts' where firstname is not 'Hattie' and lastname is not 'Frank'\nPPL: source=`accounts` | where `firstname` != 'Hattie' AND `lastname` != 'frank'\n\nQuestion: Find the emails that contain '.com' in index 'accounts'\nPPL: source=`accounts` | where QUERY_STRING(['email'], '.com') | fields `email`\n\nQuestion: Find the documents in index 'accounts' where there is an email\nPPL: source=`accounts` | where ISNOTNULL(`email`)\n\nQuestion: Count the number of documents in index 'accounts'\nPPL: source=`accounts` | stats COUNT() AS `count`\n\nQuestion: Count the number of people with firstname 'Amber' in index 'accounts'\nPPL: source=`accounts` | where `firstname` ='Amber' | stats COUNT() AS `count`\n\nQuestion: How many people are older than 33? index is 'accounts'\nPPL: source=`accounts` | where `age` > 33 | stats COUNT() AS `count`\n\nQuestion: How many distinct ages? index is 'accounts'\nPPL: source=`accounts` | stats DISTINCT_COUNT(age) AS `distinct_count`\n\nQuestion: How many males and females in index 'accounts'?\nPPL: source=`accounts` | stats COUNT() AS `count` BY `gender`\n\nQuestion: What is the average, minimum, maximum age in 'accounts' index?\nPPL: source=`accounts` | stats AVG(`age`) AS `avg_age`, MIN(`age`) AS `min_age`, MAX(`age`) AS `max_age`\n\nQuestion: Show all states sorted by average balance. 
index is 'accounts'\nPPL: source=`accounts` | stats AVG(`balance`) AS `avg_balance` BY `state` | sort +avg_balance\n\n----------------\n\nThe following text contains fields and questions/answers for the 'ecommerce' index\n\nFields:\n- category: text ('Men's Clothing')\n- currency: keyword ('EUR')\n- customer_birth_date: date (null)\n- customer_first_name: text ('Eddie')\n- customer_full_name: text ('Eddie Underwood')\n- customer_gender: keyword ('MALE')\n- customer_id: keyword ('38')\n- customer_last_name: text ('Underwood')\n- customer_phone: keyword ('')\n- day_of_week: keyword ('Monday')\n- day_of_week_i: integer (0)\n- email: keyword ('eddie@underwood-family.zzz')\n- event.dataset: keyword ('sample_ecommerce')\n- geoip.city_name: keyword ('Cairo')\n- geoip.continent_name: keyword ('Africa')\n- geoip.country_iso_code: keyword ('EG')\n- geoip.location: geo_point ([object Object])\n- geoip.region_name: keyword ('Cairo Governorate')\n- manufacturer: text ('Elitelligence,Oceanavigations')\n- order_date: date (2023-06-05T09:28:48+00:00)\n- order_id: keyword ('584677')\n- products._id: text (null)\n- products.base_price: half_float (null)\n- products.base_unit_price: half_float (null)\n- products.category: text (null)\n- products.created_on: date (null)\n- products.discount_amount: half_float (null)\n- products.discount_percentage: half_float (null)\n- products.manufacturer: text (null)\n- products.min_price: half_float (null)\n- products.price: half_float (null)\n- products.product_id: long (null)\n- products.product_name: text (null)\n- products.quantity: integer (null)\n- products.sku: keyword (null)\n- products.tax_amount: half_float (null)\n- products.taxful_price: half_float (null)\n- products.taxless_price: half_float (null)\n- products.unit_discount_amount: half_float (null)\n- sku: keyword ('ZO0549605496,ZO0299602996')\n- taxful_total_price: half_float (36.98)\n- taxless_total_price: half_float (36.98)\n- total_quantity: integer (2)\n- total_unique_products: integer (2)\n- type: keyword ('order')\n- user: keyword ('eddie')\n\nQuestion: What is the average price of products in clothing category ordered in the last 7 days? index is 'ecommerce'\nPPL: source=`ecommerce` | where QUERY_STRING(['category'], 'clothing') AND `order_date` > DATE_SUB(NOW(), INTERVAL 7 DAY) | stats AVG(`taxful_total_price`) AS `avg_price`\n\nQuestion: What is the average price of products in each city ordered today by every 2 hours? index is 'ecommerce'\nPPL: source=`ecommerce` | where `order_date` > DATE_SUB(NOW(), INTERVAL 24 HOUR) | stats AVG(`taxful_total_price`) AS `avg_price` by SPAN(`order_date`, 2h) AS `span`, `geoip.city_name`\n\nQuestion: What is the total revenue of shoes each day in this week? 
index is 'ecommerce'\nPPL: source=`ecommerce` | where QUERY_STRING(['category'], 'shoes') AND `order_date` > DATE_SUB(NOW(), INTERVAL 1 WEEK) | stats SUM(`taxful_total_price`) AS `revenue` by SPAN(`order_date`, 1d) AS `span`\n\n----------------\n\nThe following text contains fields and questions/answers for the 'events' index\nFields:\n- timestamp: long (1686000665919)\n- attributes.data_stream.dataset: text ('nginx.access')\n- attributes.data_stream.namespace: text ('production')\n- attributes.data_stream.type: text ('logs')\n- body: text ('172.24.0.1 - - [02/Jun/2023:23:09:27 +0000] 'GET / HTTP/1.1' 200 4955 '-' 'Mozilla/5.0 zgrab/0.x'')\n- communication.source.address: text ('127.0.0.1')\n- communication.source.ip: text ('172.24.0.1')\n- container_id: text (null)\n- container_name: text (null)\n- event.category: text ('web')\n- event.domain: text ('nginx.access')\n- event.kind: text ('event')\n- event.name: text ('access')\n- event.result: text ('success')\n- event.type: text ('access')\n- http.flavor: text ('1.1')\n- http.request.method: text ('GET')\n- http.response.bytes: long (4955)\n- http.response.status_code: keyword ('200')\n- http.url: text ('/')\n- log: text (null)\n- observerTime: date (1686000665919)\n- source: text (null)\n- span_id: text ('abcdef1010')\n- trace_id: text ('102981ABCD2901')\n\nQuestion: What are recent logs with errors and contains word 'test'? index is 'events'\nPPL: source=`events` | where QUERY_STRING(['http.response.status_code'], '4* OR 5*') AND QUERY_STRING(['body'], 'test') AND `observerTime` > DATE_SUB(NOW(), INTERVAL 5 MINUTE)\n\nQuestion: What is the total number of log with a status code other than 200 in 2023 Feburary? index is 'events'\nPPL: source=`events` | where QUERY_STRING(['http.response.status_code'], '!200') AND `observerTime` >= '2023-03-01 00:00:00' AND `observerTime` < '2023-04-01 00:00:00' | stats COUNT() AS `count`\n\nQuestion: Count the number of business days that have web category logs last week? index is 'events'\nPPL: source=`events` | where `category` = 'web' AND `observerTime` > DATE_SUB(NOW(), INTERVAL 1 WEEK) AND DAY_OF_WEEK(`observerTime`) >= 2 AND DAY_OF_WEEK(`observerTime`) <= 6 | stats DISTINCT_COUNT(DATE_FORMAT(`observerTime`, 'yyyy-MM-dd')) AS `distinct_count`\n\nQuestion: What are the top traces with largest bytes? index is 'events'\nPPL: source=`events` | stats SUM(`http.response.bytes`) AS `sum_bytes` by `trace_id` | sort -sum_bytes | head\n\nQuestion: Give me log patterns? index is 'events'\nPPL: source=`events` | patterns `body` | stats take(`body`, 1) AS `sample_pattern` by `patterns_field` | fields `sample_pattern`\n\nQuestion: Give me log patterns for logs with errors? index is 'events'\nPPL: source=`events` | where QUERY_STRING(['http.response.status_code'], '4* OR 5*') | patterns `body` | stats take(`body`, 1) AS `sample_pattern` by `patterns_field` | fields `sample_pattern`\n\n----------------\n\nUse the following steps to generate the PPL query:\n\nStep 1. Find all field entities in the question.\n\nStep 2. Pick the fields that are relevant to the question from the provided fields list using entities. Rules:\n#01 Consider the field name, the field type, and the sample value when picking relevant fields. For example, if you need to filter flights departed from 'JFK', look for a `text` or `keyword` field with a field name such as 'departedAirport', and the sample value should be a 3 letter IATA airport code. 
Similarly, if you need a date field, look for a relevant field name with type `date` and not `long`.\n#02 You must pick a field with `date` type when filtering on date/time.\n#03 You must pick a field with `date` type when aggregating by time interval.\n#04 You must not use the sample value in PPL query, unless it is relevant to the question.\n#05 You must only pick fields that are relevant, and must pick the whole field name from the fields list.\n#06 You must not use fields that are not in the fields list.\n#07 You must not use the sample values unless relevant to the question.\n#08 You must pick the field that contains a log line when asked about log patterns. Usually it is one of `log`, `body`, `message`.\n\nStep 3. Use the choosen fields to write the PPL query. Rules:\n#01 Always use comparisons to filter date/time, eg. 'where `timestamp` > DATE_SUB(NOW(), INTERVAL 1 DAY)'; or by absolute time: 'where `timestamp` > 'yyyy-MM-dd HH:mm:ss'', eg. 'where `timestamp` < '2023-01-01 00:00:00''. Do not use `DATE_FORMAT()`.\n#02 Only use PPL syntax and keywords appeared in the question or in the examples.\n#03 If user asks for current or recent status, filter the time field for last 5 minutes.\n#04 The field used in 'SPAN(``, )' must have type `date`, not `long`.\n#05 When aggregating by `SPAN` and another field, put `SPAN` after `by` and before the other field, eg. 'stats COUNT() AS `count` by SPAN(`timestamp`, 1d) AS `span`, `category`'.\n#06 You must put values in quotes when filtering fields with `text` or `keyword` field type.\n#07 To find documents that contain certain phrases in string fields, use `QUERY_STRING` which supports multiple fields and wildcard, eg. 'where QUERY_STRING(['field1', 'field2'], 'prefix*')'.\n#08 To find 4xx and 5xx errors using status code, if the status code field type is numberic (eg. `integer`), then use 'where `status_code` >= 400'; if the field is a string (eg. `text` or `keyword`), then use 'where QUERY_STRING(['status_code'], '4* OR 5*')'.\n\n----------------\nOutput format: use xml tags to surround your PPL query, eg. source=index.\n----------------\nQuestion : ${indexInfo.question}? 
index is `${indexInfo.indexName}`\nFields:\n${indexInfo.mappingInfo}" +} diff --git a/src/test/java/org/opensearch/agent/TestHelpers.java b/src/test/java/org/opensearch/agent/TestHelpers.java new file mode 100644 index 00000000..422808de --- /dev/null +++ b/src/test/java/org/opensearch/agent/TestHelpers.java @@ -0,0 +1,64 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent; + +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; + +import org.apache.lucene.search.TotalHits; +import org.mockito.Mockito; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.transport.GetAnomalyDetectorResponse; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.Aggregations; + +public class TestHelpers { + + public static SearchResponse generateSearchResponse(SearchHit[] hits) { + TotalHits totalHits = new TotalHits(hits.length, TotalHits.Relation.EQUAL_TO); + return new SearchResponse( + new SearchResponseSections(new SearchHits(hits, totalHits, 0), new Aggregations(new ArrayList<>()), null, false, null, null, 0), + null, + 0, + 0, + 0, + 0, + null, + null + ); + } + + public static GetAnomalyDetectorResponse generateGetAnomalyDetectorResponses(String[] detectorNames, String[] detectorStates) { + AnomalyDetector detector = Mockito.mock(AnomalyDetector.class); + // For each subsequent call to getId(), return the next detectorId in the array + when(detector.getName()).thenReturn(detectorNames[0], Arrays.copyOfRange(detectorNames, 1, detectorNames.length)); + ADTask realtimeAdTask = Mockito.mock(ADTask.class); + // For each subsequent call to getState(), return the next detectorState in the array + when(realtimeAdTask.getState()).thenReturn(detectorStates[0], Arrays.copyOfRange(detectorStates, 1, detectorStates.length)); + GetAnomalyDetectorResponse getDetectorProfileResponse = Mockito.mock(GetAnomalyDetectorResponse.class); + when(getDetectorProfileResponse.getRealtimeAdTask()).thenReturn(realtimeAdTask); + when(getDetectorProfileResponse.getDetector()).thenReturn(detector); + return getDetectorProfileResponse; + } + + public static SearchHit generateSearchDetectorHit(String detectorName, String detectorId) throws IOException { + XContentBuilder content = XContentBuilder.builder(XContentType.JSON.xContent()); + content.startObject(); + content.field("name", detectorName); + content.endObject(); + return new SearchHit(0, detectorId, null, null).sourceRef(BytesReference.bytes(content)); + } +} diff --git a/src/test/java/org/opensearch/agent/tools/AbstractRetrieverToolTests.java b/src/test/java/org/opensearch/agent/tools/AbstractRetrieverToolTests.java new file mode 100644 index 00000000..04b5f473 --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/AbstractRetrieverToolTests.java @@ -0,0 +1,246 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; +import static 
org.opensearch.agent.tools.AbstractRetrieverTool.DEFAULT_DESCRIPTION; + +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.search.SearchModule; + +import lombok.SneakyThrows; + +public class AbstractRetrieverToolTests { + static public final String TEST_QUERY = "{\"query\":{\"match_all\":{}}}"; + static public final String TEST_INDEX = "test index"; + static public final String[] TEST_SOURCE_FIELDS = new String[] { "test 1", "test 2" }; + static public final Integer TEST_DOC_SIZE = 3; + static public final NamedXContentRegistry TEST_XCONTENT_REGISTRY_FOR_QUERY = new NamedXContentRegistry( + new SearchModule(Settings.EMPTY, List.of()).getNamedXContents() + ); + + private String mockedSearchResponseString; + private String mockedEmptySearchResponseString; + private AbstractRetrieverTool mockedImpl; + + @Before + @SneakyThrows + public void setup() { + try (InputStream searchResponseIns = AbstractRetrieverTool.class.getResourceAsStream("retrieval_tool_search_response.json")) { + if (searchResponseIns != null) { + mockedSearchResponseString = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8); + } + } + try (InputStream searchResponseIns = AbstractRetrieverTool.class.getResourceAsStream("retrieval_tool_empty_search_response.json")) { + if (searchResponseIns != null) { + mockedEmptySearchResponseString = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8); + } + } + + mockedImpl = Mockito + .mock( + AbstractRetrieverTool.class, + Mockito + .withSettings() + .useConstructor(null, TEST_XCONTENT_REGISTRY_FOR_QUERY, TEST_INDEX, TEST_SOURCE_FIELDS, TEST_DOC_SIZE) + .defaultAnswer(Mockito.CALLS_REAL_METHODS) + ); + when(mockedImpl.getQueryBody(any(String.class))).thenReturn(TEST_QUERY); + } + + @Test + @SneakyThrows + public void testRunAsyncWithSearchResults() { + Client client = mock(Client.class); + SearchResponse mockedSearchResponse = SearchResponse + .fromXContent( + JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, mockedSearchResponseString) + ); + doAnswer(invocation -> { + SearchRequest searchRequest = invocation.getArgument(0); + assertEquals((long) TEST_DOC_SIZE, (long) searchRequest.source().size()); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mockedSearchResponse); + return null; + }).when(client).search(any(), any()); + mockedImpl.setClient(client); + + final CompletableFuture future = new CompletableFuture<>(); + ActionListener listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); }); + + mockedImpl.run(Map.of(AbstractRetrieverTool.INPUT_FIELD, "hello world"), listener); + + future.join(); + assertEquals( + "{\"_index\":\"hybrid-index\",\"_source\":{\"passage_text\":\"Company 
test_mock have a history of 100 years.\"},\"_id\":\"1\",\"_score\":89.2917}\n" + + "{\"_index\":\"hybrid-index\",\"_source\":{\"passage_text\":\"the price of the api is 2$ per invocation\"},\"_id\":\"2\",\"_score\":0.10702579}\n", + future.get() + ); + } + + @Test + @SneakyThrows + public void testRunAsyncWithEmptySearchResponse() { + Client client = mock(Client.class); + SearchResponse mockedEmptySearchResponse = SearchResponse + .fromXContent( + JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, mockedEmptySearchResponseString) + ); + doAnswer(invocation -> { + SearchRequest searchRequest = invocation.getArgument(0); + assertEquals((long) TEST_DOC_SIZE, (long) searchRequest.source().size()); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mockedEmptySearchResponse); + return null; + }).when(client).search(any(), any()); + mockedImpl.setClient(client); + + final CompletableFuture future = new CompletableFuture<>(); + ActionListener listener = ActionListener.wrap(r -> { future.complete(r); }, e -> { future.completeExceptionally(e); }); + + mockedImpl.run(Map.of(AbstractRetrieverTool.INPUT_FIELD, "hello world"), listener); + + future.join(); + assertEquals("Can not get any match from search result.", future.get()); + } + + @Test + @SneakyThrows + public void testRunAsyncWithIllegalQueryThenListenerOnFailure() { + Client client = mock(Client.class); + mockedImpl.setClient(client); + + final CompletableFuture future1 = new CompletableFuture<>(); + ActionListener listener1 = ActionListener.wrap(future1::complete, future1::completeExceptionally); + mockedImpl.run(Map.of(AbstractRetrieverTool.INPUT_FIELD, ""), listener1); + + Exception exception1 = assertThrows(Exception.class, future1::join); + assertTrue(exception1.getCause() instanceof IllegalArgumentException); + assertEquals(exception1.getCause().getMessage(), "[input] is null or empty, can not process it."); + + final CompletableFuture future2 = new CompletableFuture<>(); + ActionListener listener2 = ActionListener.wrap(future2::complete, future2::completeExceptionally); + mockedImpl.run(Map.of(AbstractRetrieverTool.INPUT_FIELD, " "), listener2); + + Exception exception2 = assertThrows(Exception.class, future2::join); + assertTrue(exception2.getCause() instanceof IllegalArgumentException); + assertEquals(exception2.getCause().getMessage(), "[input] is null or empty, can not process it."); + + final CompletableFuture future3 = new CompletableFuture<>(); + ActionListener listener3 = ActionListener.wrap(future3::complete, future3::completeExceptionally); + mockedImpl.run(Map.of("test", "hello world"), listener3); + + Exception exception3 = assertThrows(Exception.class, future3::join); + assertTrue(exception3.getCause() instanceof IllegalArgumentException); + assertEquals(exception3.getCause().getMessage(), "[input] is null or empty, can not process it."); + + final CompletableFuture future4 = new CompletableFuture<>(); + ActionListener listener4 = ActionListener.wrap(future4::complete, future4::completeExceptionally); + mockedImpl.run(null, listener4); + + Exception exception4 = assertThrows(Exception.class, future4::join); + assertTrue(exception4.getCause() instanceof NullPointerException); + } + + @Test + @SneakyThrows + public void testValidate() { + assertTrue(mockedImpl.validate(Map.of(AbstractRetrieverTool.INPUT_FIELD, "hi"))); + assertFalse(mockedImpl.validate(Map.of(AbstractRetrieverTool.INPUT_FIELD, ""))); + 
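// blank, missing, and null inputs must all fail validation +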
assertFalse(mockedImpl.validate(Map.of(AbstractRetrieverTool.INPUT_FIELD, " "))); + assertFalse(mockedImpl.validate(Map.of("test", " "))); + assertFalse(mockedImpl.validate(new HashMap<>())); + assertFalse(mockedImpl.validate(null)); + } + + @Test + public void testGetAttributes() { + assertEquals(mockedImpl.getVersion(), null); + assertEquals(mockedImpl.getIndex(), TEST_INDEX); + assertEquals(mockedImpl.getDocSize(), TEST_DOC_SIZE); + assertEquals(mockedImpl.getSourceFields(), TEST_SOURCE_FIELDS); + assertEquals(mockedImpl.getQueryBody(TEST_QUERY), TEST_QUERY); + } + + @Test + public void testGetQueryBodySuccess() { + assertEquals(mockedImpl.getQueryBody(TEST_QUERY), TEST_QUERY); + } + + @Test + @SneakyThrows + public void testRunWithRuntimeException() { + Client client = mock(Client.class); + mockedImpl.setClient(client); + ActionListener listener = mock(ActionListener.class); + doAnswer(invocation -> { + SearchRequest searchRequest = invocation.getArgument(0); + assertEquals((long) TEST_DOC_SIZE, (long) searchRequest.source().size()); + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new RuntimeException("Failed to search index")); + return null; + }).when(client).search(any(), any()); + mockedImpl.run(Map.of(AbstractRetrieverTool.INPUT_FIELD, "hello world"), listener); + verify(listener).onFailure(any(RuntimeException.class)); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to search index", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testFactory() { + // Create a mock object of the abstract Factory class + Client client = mock(Client.class); + AbstractRetrieverTool.Factory factoryMock = new AbstractRetrieverTool.Factory<>() { + public AbstractRetrieverTool create(Map params) { + return null; + } + + @Override + public String getDefaultType() { + return null; + } + + @Override + public String getDefaultVersion() { + return null; + } + }; + + factoryMock.init(client, TEST_XCONTENT_REGISTRY_FOR_QUERY); + + assertNotNull(factoryMock.client); + assertNotNull(factoryMock.xContentRegistry); + assertEquals(client, factoryMock.client); + assertEquals(TEST_XCONTENT_REGISTRY_FOR_QUERY, factoryMock.xContentRegistry); + + String defaultDescription = factoryMock.getDefaultDescription(); + assertEquals(DEFAULT_DESCRIPTION, defaultDescription); + } +} diff --git a/src/test/java/org/opensearch/agent/tools/CreateAnomalyDetectorToolTests.java b/src/test/java/org/opensearch/agent/tools/CreateAnomalyDetectorToolTests.java new file mode 100644 index 00000000..0749ab70 --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/CreateAnomalyDetectorToolTests.java @@ -0,0 +1,280 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import 
org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.IndicesAdminClient; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; + +import com.google.common.collect.ImmutableMap; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class CreateAnomalyDetectorToolTests { + @Mock + private Client client; + @Mock + private AdminClient adminClient; + @Mock + private IndicesAdminClient indicesAdminClient; + @Mock + private GetMappingsResponse getMappingsResponse; + @Mock + private MappingMetadata mappingMetadata; + private Map mockedMappings; + private Map indexMappings; + + @Mock + private MLTaskResponse mlTaskResponse; + @Mock + private ModelTensorOutput modelTensorOutput; + @Mock + private ModelTensors modelTensors; + + private ModelTensor modelTensor; + + private Map modelReturns; + + private String mockedIndexName = "http_logs"; + private String mockedResponse = "{category_field=|aggregation_field=response,responseLatency|aggregation_method=count,avg}"; + private String mockedResult = + "{\"index\":\"http_logs\",\"categoryField\":\"\",\"aggregationField\":\"response,responseLatency\",\"aggregationMethod\":\"count,avg\",\"dateFields\":\"date\"}"; + + private String mockedResultForIndexPattern = + "{\"index\":\"http_logs*\",\"categoryField\":\"\",\"aggregationField\":\"response,responseLatency\",\"aggregationMethod\":\"count,avg\",\"dateFields\":\"date\"}"; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + createMappings(); + // get mapping + when(mappingMetadata.getSourceAsMap()).thenReturn(indexMappings); + when(getMappingsResponse.getMappings()).thenReturn(mockedMappings); + when(client.admin()).thenReturn(adminClient); + when(adminClient.indices()).thenReturn(indicesAdminClient); + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onResponse(getMappingsResponse); + return null; + }).when(indicesAdminClient).getMappings(any(), any()); + + initMLTensors(); + CreateAnomalyDetectorTool.Factory.getInstance().init(client); + } + + @Test + public void testModelIdIsNullOrEmpty() { + Exception exception = assertThrows( + IllegalArgumentException.class, + () -> CreateAnomalyDetectorTool.Factory.getInstance().create(ImmutableMap.of("model_id", "")) + ); + assertEquals("model_id cannot be empty.", exception.getMessage()); + } + + @Test + public void testModelType() { + Exception exception = assertThrows( + IllegalArgumentException.class, + () -> CreateAnomalyDetectorTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "model_type", "unknown")) + ); + assertEquals("Unsupported model_type: unknown", exception.getMessage()); + + CreateAnomalyDetectorTool tool = CreateAnomalyDetectorTool.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId", "model_type", "openai")); + assertEquals(CreateAnomalyDetectorTool.TYPE, tool.getName()); + assertEquals("modelId", tool.getModelId()); + assertEquals("OPENAI", 
tool.getModelType().toString()); + + tool = CreateAnomalyDetectorTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "model_type", "claude")); + assertEquals(CreateAnomalyDetectorTool.TYPE, tool.getName()); + assertEquals("modelId", tool.getModelId()); + assertEquals("CLAUDE", tool.getModelType().toString()); + } + + @Test + public void testTool() { + CreateAnomalyDetectorTool tool = CreateAnomalyDetectorTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId")); + assertEquals(CreateAnomalyDetectorTool.TYPE, tool.getName()); + assertEquals("modelId", tool.getModelId()); + assertEquals("CLAUDE", tool.getModelType().toString()); + + tool + .run( + ImmutableMap.of("index", mockedIndexName), + ActionListener.wrap(response -> assertEquals(mockedResult, response), log::info) + ); + tool + .run( + ImmutableMap.of("index", mockedIndexName + "*"), + ActionListener.wrap(response -> assertEquals(mockedResultForIndexPattern, response), log::info) + ); + tool + .run( + ImmutableMap.of("input", mockedIndexName), + ActionListener.wrap(response -> assertEquals(mockedResult, response), log::info) + ); + tool + .run( + ImmutableMap.of("input", gson.toJson(ImmutableMap.of("index", mockedIndexName))), + ActionListener.wrap(response -> assertEquals(mockedResult, response), log::info) + ); + } + + @Test + public void testToolWithInvalidResponse() { + CreateAnomalyDetectorTool tool = CreateAnomalyDetectorTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId")); + + modelReturns = Collections.singletonMap("response", ""); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns); + initMLTensors(); + + Exception exception = assertThrows( + IllegalStateException.class, + () -> tool + .run(ImmutableMap.of("index", mockedIndexName), ActionListener.wrap(response -> assertEquals(response, ""), e -> { + throw new IllegalStateException(e.getMessage()); + })) + ); + assertEquals("Remote endpoint fails to inference, no response found.", exception.getMessage()); + + modelReturns = Collections.singletonMap("response", "not valid response"); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns); + initMLTensors(); + + exception = assertThrows( + IllegalStateException.class, + () -> tool + .run( + ImmutableMap.of("index", mockedIndexName), + ActionListener.wrap(response -> assertEquals(response, "not valid response"), e -> { + throw new IllegalStateException(e.getMessage()); + }) + ) + ); + assertEquals( + "The inference result from remote endpoint is not valid, cannot extract the key information from the result.", + exception.getMessage() + ); + + modelReturns = Collections.singletonMap("response", null); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns); + initMLTensors(); + + exception = assertThrows( + IllegalStateException.class, + () -> tool + .run(ImmutableMap.of("index", mockedIndexName), ActionListener.wrap(response -> assertEquals(response, ""), e -> { + throw new IllegalStateException(e.getMessage()); + })) + ); + assertEquals("Remote endpoint fails to inference, no response found.", exception.getMessage()); + } + + @Test + public void testToolWithSystemIndex() { + CreateAnomalyDetectorTool tool = CreateAnomalyDetectorTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId")); + Exception exception = assertThrows( + IllegalArgumentException.class, + 
() -> tool.run(ImmutableMap.of("index", ML_CONNECTOR_INDEX), ActionListener.wrap(result -> {}, e -> {})) + ); + assertEquals( + "CreateAnomalyDetectionTool doesn't support searching indices starting with '.' since it could be system index, current searching index name: " + + ML_CONNECTOR_INDEX, + exception.getMessage() + ); + } + + @Test + public void testToolWithGetMappingFailed() { + CreateAnomalyDetectorTool tool = CreateAnomalyDetectorTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId")); + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onFailure(new Exception("No mapping found for the index: " + mockedIndexName)); + return null; + }).when(indicesAdminClient).getMappings(any(), any()); + + tool.run(ImmutableMap.of("index", mockedIndexName), ActionListener.wrap(result -> {}, e -> { + assertEquals("No mapping found for the index: " + mockedIndexName, e.getMessage()); + })); + } + + @Test + public void testToolWithPredictModelFailed() { + CreateAnomalyDetectorTool tool = CreateAnomalyDetectorTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId")); + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new Exception("predict model failed")); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + + tool.run(ImmutableMap.of("index", mockedIndexName), ActionListener.wrap(result -> {}, e -> { + assertEquals("predict model failed", e.getMessage()); + })); + } + + private void createMappings() { + indexMappings = new HashMap<>(); + indexMappings + .put( + "properties", + ImmutableMap + .of( + "response", + ImmutableMap.of("type", "integer"), + "responseLatency", + ImmutableMap.of("type", "float"), + "date", + ImmutableMap.of("type", "date") + ) + ); + mockedMappings = new HashMap<>(); + mockedMappings.put(mockedIndexName, mappingMetadata); + + modelReturns = Collections.singletonMap("response", mockedResponse); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, modelReturns); + } + + private void initMLTensors() { + when(modelTensors.getMlModelTensors()).thenReturn(Collections.singletonList(modelTensor)); + when(modelTensorOutput.getMlModelOutputs()).thenReturn(Collections.singletonList(modelTensors)); + when(mlTaskResponse.getOutput()).thenReturn(modelTensorOutput); + + // call model + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + } +} diff --git a/src/test/java/org/opensearch/agent/tools/NeuralSparseSearchToolTests.java b/src/test/java/org/opensearch/agent/tools/NeuralSparseSearchToolTests.java new file mode 100644 index 00000000..d6d14991 --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/NeuralSparseSearchToolTests.java @@ -0,0 +1,146 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; + +import com.google.gson.JsonSyntaxException; + +import lombok.SneakyThrows; + +public class 
NeuralSparseSearchToolTests { + public static final String TEST_QUERY_TEXT = "123fsd23134sdfouh"; + public static final String TEST_EMBEDDING_FIELD = "test embedding"; + public static final String TEST_MODEL_ID = "123fsd23134"; + public static final String TEST_NESTED_PATH = "nested_path"; + private Map<String, String> params = new HashMap<>(); + + @Before + public void setup() { + params.put(NeuralSparseSearchTool.INDEX_FIELD, AbstractRetrieverToolTests.TEST_INDEX); + params.put(NeuralSparseSearchTool.EMBEDDING_FIELD, TEST_EMBEDDING_FIELD); + params.put(NeuralSparseSearchTool.SOURCE_FIELD, gson.toJson(AbstractRetrieverToolTests.TEST_SOURCE_FIELDS)); + params.put(NeuralSparseSearchTool.MODEL_ID_FIELD, TEST_MODEL_ID); + params.put(NeuralSparseSearchTool.DOC_SIZE_FIELD, AbstractRetrieverToolTests.TEST_DOC_SIZE.toString()); + } + + @Test + @SneakyThrows + public void testCreateTool() { + NeuralSparseSearchTool tool = NeuralSparseSearchTool.Factory.getInstance().create(params); + assertEquals(AbstractRetrieverToolTests.TEST_INDEX, tool.getIndex()); + assertEquals(TEST_EMBEDDING_FIELD, tool.getEmbeddingField()); + assertEquals(AbstractRetrieverToolTests.TEST_SOURCE_FIELDS, tool.getSourceFields()); + assertEquals(TEST_MODEL_ID, tool.getModelId()); + assertEquals(AbstractRetrieverToolTests.TEST_DOC_SIZE, tool.getDocSize()); + assertEquals("NeuralSparseSearchTool", tool.getType()); + assertEquals("NeuralSparseSearchTool", tool.getName()); + assertEquals( + "Use this tool to search data in OpenSearch index.", + NeuralSparseSearchTool.Factory.getInstance().getDefaultDescription() + ); + } + + @Test + @SneakyThrows + public void testGetQueryBody() { + NeuralSparseSearchTool tool = NeuralSparseSearchTool.Factory.getInstance().create(params); + Map<String, Map<String, Map<String, Map<String, String>>>> queryBody = gson.fromJson(tool.getQueryBody(TEST_QUERY_TEXT), Map.class); + assertEquals("123fsd23134sdfouh", queryBody.get("query").get("neural_sparse").get("test embedding").get("query_text")); + assertEquals("123fsd23134", queryBody.get("query").get("neural_sparse").get("test embedding").get("model_id")); + } + + @Test + @SneakyThrows + public void testGetQueryBodyWithNestedPath() { + params.put(NeuralSparseSearchTool.NESTED_PATH_FIELD, TEST_NESTED_PATH); + NeuralSparseSearchTool tool = NeuralSparseSearchTool.Factory.getInstance().create(params); + Map<String, Map<String, Map<String, Object>>> nestedQueryBody = gson.fromJson(tool.getQueryBody(TEST_QUERY_TEXT), Map.class); + assertEquals("nested_path", nestedQueryBody.get("query").get("nested").get("path")); + assertEquals("max", nestedQueryBody.get("query").get("nested").get("score_mode")); + Map<String, Map<String, Map<String, String>>> queryBody = (Map<String, Map<String, Map<String, String>>>) nestedQueryBody + .get("query") + .get("nested") + .get("query"); + assertEquals("123fsd23134sdfouh", queryBody.get("neural_sparse").get("test embedding").get("query_text")); + assertEquals("123fsd23134", queryBody.get("neural_sparse").get("test embedding").get("model_id")); + } + + @Test + @SneakyThrows + public void testGetQueryBodyWithJsonObjectString() { + NeuralSparseSearchTool tool = NeuralSparseSearchTool.Factory.getInstance().create(params); + String jsonInput = gson.toJson(Map.of("hi", "a")); + Map<String, Map<String, Map<String, Map<String, String>>>> queryBody = gson.fromJson(tool.getQueryBody(jsonInput), Map.class); + assertEquals("{\"hi\":\"a\"}", queryBody.get("query").get("neural_sparse").get("test embedding").get("query_text")); + assertEquals("123fsd23134", queryBody.get("query").get("neural_sparse").get("test embedding").get("model_id")); + } + + @Test + @SneakyThrows + public void testGetQueryBodyWithIllegalParams() { + Map<String, String> illegalParams1 = new HashMap<>(params); + 
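// removing the required model_id should make getQueryBody throw an IllegalArgumentException +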
illegalParams1.remove(NeuralSparseSearchTool.MODEL_ID_FIELD); + NeuralSparseSearchTool tool1 = NeuralSparseSearchTool.Factory.getInstance().create(illegalParams1); + Exception exception1 = assertThrows( + IllegalArgumentException.class, + () -> tool1.getQueryBody(AbstractRetrieverToolTests.TEST_QUERY) + ); + assertEquals("Parameter [embedding_field] and [model_id] can not be null or empty.", exception1.getMessage()); + + Map<String, String> illegalParams2 = new HashMap<>(params); + illegalParams2.remove(NeuralSparseSearchTool.EMBEDDING_FIELD); + NeuralSparseSearchTool tool2 = NeuralSparseSearchTool.Factory.getInstance().create(illegalParams2); + Exception exception2 = assertThrows( + IllegalArgumentException.class, + () -> tool2.getQueryBody(AbstractRetrieverToolTests.TEST_QUERY) + ); + assertEquals("Parameter [embedding_field] and [model_id] can not be null or empty.", exception2.getMessage()); + } + + @Test + @SneakyThrows + public void testCreateToolsParseParams() { + assertThrows( + ClassCastException.class, + () -> NeuralSparseSearchTool.Factory.getInstance().create(Map.of(NeuralSparseSearchTool.INDEX_FIELD, 123)) + ); + + assertThrows( + ClassCastException.class, + () -> NeuralSparseSearchTool.Factory.getInstance().create(Map.of(NeuralSparseSearchTool.EMBEDDING_FIELD, 123)) + ); + + assertThrows( + ClassCastException.class, + () -> NeuralSparseSearchTool.Factory.getInstance().create(Map.of(NeuralSparseSearchTool.MODEL_ID_FIELD, 123)) + ); + + assertThrows( + ClassCastException.class, + () -> NeuralSparseSearchTool.Factory.getInstance().create(Map.of(NeuralSparseSearchTool.NESTED_PATH_FIELD, 123)) + ); + + assertThrows( + JsonSyntaxException.class, + () -> NeuralSparseSearchTool.Factory.getInstance().create(Map.of(NeuralSparseSearchTool.SOURCE_FIELD, "123")) + ); + + // doc_size is numeric, but tool parameter values must always be passed as Strings, so an Integer value triggers a ClassCastException + assertThrows( + ClassCastException.class, + () -> NeuralSparseSearchTool.Factory.getInstance().create(Map.of(NeuralSparseSearchTool.DOC_SIZE_FIELD, 123)) + ); + } +} diff --git a/src/test/java/org/opensearch/agent/tools/PPLToolTests.java b/src/test/java/org/opensearch/agent/tools/PPLToolTests.java new file mode 100644 index 00000000..ae1baa31 --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/PPLToolTests.java @@ -0,0 +1,444 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.apache.lucene.search.TotalHits; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.admin.indices.mapping.get.GetMappingsResponse; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.IndicesAdminClient; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.common.bytes.BytesReference; +import 
org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.sql.plugin.transport.PPLQueryAction; +import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse; + +import com.google.common.collect.ImmutableMap; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class PPLToolTests { + @Mock + private Client client; + @Mock + private AdminClient adminClient; + @Mock + private IndicesAdminClient indicesAdminClient; + @Mock + private GetMappingsResponse getMappingsResponse; + @Mock + private MappingMetadata mappingMetadata; + private Map mockedMappings; + private Map indexMappings; + + private SearchHits searchHits; + + private SearchHit hit; + @Mock + private SearchResponse searchResponse; + + private Map sampleMapping; + + @Mock + private MLTaskResponse mlTaskResponse; + @Mock + private ModelTensorOutput modelTensorOutput; + @Mock + private ModelTensors modelTensors; + + private ModelTensor modelTensor; + + private Map pplReturns; + + @Mock + private TransportPPLQueryResponse transportPPLQueryResponse; + + private String mockedIndexName = "demo"; + + private String pplResult = "ppl result"; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + createMappings(); + // get mapping + when(mappingMetadata.getSourceAsMap()).thenReturn(indexMappings); + when(getMappingsResponse.getMappings()).thenReturn(mockedMappings); + when(client.admin()).thenReturn(adminClient); + when(adminClient.indices()).thenReturn(indicesAdminClient); + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onResponse(getMappingsResponse); + return null; + }).when(indicesAdminClient).getMappings(any(), any()); + // mockedMappings (index name, mappingmetadata) + + // search result + + when(searchResponse.getHits()).thenReturn(searchHits); + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + + initMLTensors(); + + when(transportPPLQueryResponse.getResult()).thenReturn(pplResult); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(transportPPLQueryResponse); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + PPLTool.Factory.getInstance().init(client); + } + + @Test + public void testTool_WithoutModelId() { + Exception exception = assertThrows( + IllegalArgumentException.class, + () -> PPLTool.Factory.getInstance().create(ImmutableMap.of("prompt", "contextPrompt")) + ); + assertEquals("PPL tool needs non blank model id.", exception.getMessage()); + } + + @Test + public void testTool_WithBlankModelId() { + Exception exception = assertThrows( + IllegalArgumentException.class, + () -> PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", " ")) + ); + assertEquals("PPL tool needs non blank model id.", exception.getMessage()); + } + + @Test + public void testTool_WithNonIntegerHead() { + Exception exception = assertThrows( + IllegalArgumentException.class, + () -> 
PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "demo", "head", "11.5")) + ); + assertEquals("PPL tool parameter head must be integer.", exception.getMessage()); + } + + @Test + public void testTool_WithNonBooleanExecute() { + Exception exception = assertThrows( + IllegalArgumentException.class, + () -> PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "demo", "execute", "hello")) + ); + assertEquals("PPL tool parameter execute must be false or true", exception.getMessage()); + } + + @Test + public void testTool() { + PPLTool tool = PPLTool.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt", "head", "100")); + assertEquals(PPLTool.TYPE, tool.getName()); + + tool.run(ImmutableMap.of("index", "demo", "question", "demo"), ActionListener.wrap(executePPLResult -> { + Map returnResults = gson.fromJson(executePPLResult, Map.class); + assertEquals("ppl result", returnResults.get("executionResult")); + assertEquals("source=demo| head 1", returnResults.get("ppl")); + }, e -> { log.info(e); })); + + } + + @Test + public void testTool_withPreviousInput() { + PPLTool tool = PPLTool.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt", "previous_tool_name", "previousTool", "head", "-5")); + assertEquals(PPLTool.TYPE, tool.getName()); + + tool.run(ImmutableMap.of("previousTool.output", "demo", "question", "demo"), ActionListener.wrap(executePPLResult -> { + Map returnResults = gson.fromJson(executePPLResult, Map.class); + assertEquals("ppl result", returnResults.get("executionResult")); + assertEquals("source=demo| head 1", returnResults.get("ppl")); + }, e -> { log.info(e); })); + + } + + @Test + public void testTool_withHEADButIgnore() { + PPLTool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt", "head", "5")); + assertEquals(PPLTool.TYPE, tool.getName()); + + tool.run(ImmutableMap.of("index", "demo", "question", "demo"), ActionListener.wrap(executePPLResult -> { + Map returnResults = gson.fromJson(executePPLResult, Map.class); + assertEquals("ppl result", returnResults.get("executionResult")); + assertEquals("source=demo| head 1", returnResults.get("ppl")); + }, e -> { log.info(e); })); + + } + + @Test + public void testTool_withHEAD() { + pplReturns = Collections.singletonMap("response", "source=demo"); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, pplReturns); + initMLTensors(); + + PPLTool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt", "head", "5")); + assertEquals(PPLTool.TYPE, tool.getName()); + + tool.run(ImmutableMap.of("index", "demo", "question", "demo"), ActionListener.wrap(executePPLResult -> { + Map returnResults = gson.fromJson(executePPLResult, Map.class); + assertEquals("ppl result", returnResults.get("executionResult")); + assertEquals("source=demo | head 5", returnResults.get("ppl")); + }, e -> { log.info(e); })); + + } + + @Test + public void testTool_with_WithoutExecution() { + PPLTool tool = PPLTool.Factory + .getInstance() + .create(ImmutableMap.of("model_id", "modelId", "model_type", "claude", "execute", "false")); + assertEquals(PPLTool.TYPE, tool.getName()); + + tool.run(ImmutableMap.of("index", "demo", "question", "demo"), ActionListener.wrap(executePPLResult -> { + Map ret = gson.fromJson(executePPLResult, Map.class); + assertEquals("source=demo| head 1", 
ret.get("ppl")); + }, e -> { log.info(e); })); + + } + + @Test + public void testTool_with_DefaultPrompt() { + PPLTool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "model_type", "claude")); + assertEquals(PPLTool.TYPE, tool.getName()); + + tool.run(ImmutableMap.of("index", "demo", "question", "demo"), ActionListener.wrap(executePPLResult -> { + Map returnResults = gson.fromJson(executePPLResult, Map.class); + assertEquals("ppl result", returnResults.get("executionResult")); + assertEquals("source=demo| head 1", returnResults.get("ppl")); + }, e -> { log.info(e); })); + + } + + @Test + public void testTool_withPPLTag() { + PPLTool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt")); + assertEquals(PPLTool.TYPE, tool.getName()); + + pplReturns = Collections.singletonMap("response", "source=demo\n|\n\rhead 1"); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, pplReturns); + initMLTensors(); + + tool.run(ImmutableMap.of("index", "demo", "question", "demo"), ActionListener.wrap(executePPLResult -> { + Map returnResults = gson.fromJson(executePPLResult, Map.class); + assertEquals("ppl result", returnResults.get("executionResult")); + assertEquals("source=demo|head 1", returnResults.get("ppl")); + }, e -> { log.info(e); })); + + } + + @Test + public void testTool_withDescribeStartPPL() { + PPLTool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt")); + assertEquals(PPLTool.TYPE, tool.getName()); + + pplReturns = Collections.singletonMap("response", "describe demo"); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, pplReturns); + initMLTensors(); + + tool.run(ImmutableMap.of("index", "demo", "question", "demo"), ActionListener.wrap(executePPLResult -> { + Map returnResults = gson.fromJson(executePPLResult, Map.class); + assertEquals("ppl result", returnResults.get("executionResult")); + assertEquals("describe demo", returnResults.get("ppl")); + }, e -> { log.info(e); })); + + } + + @Test + public void testTool_querySystemIndex() { + PPLTool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt")); + assertEquals(PPLTool.TYPE, tool.getName()); + Exception exception = assertThrows( + IllegalArgumentException.class, + () -> tool.run(ImmutableMap.of("index", ML_CONNECTOR_INDEX, "question", "demo"), ActionListener.wrap(ppl -> { + assertEquals(pplResult, "ppl result"); + }, e -> { assertEquals("We cannot search system indices " + ML_CONNECTOR_INDEX, e.getMessage()); })) + ); + assertEquals( + "PPLTool doesn't support searching indices starting with '.' 
since it could be system index, current searching index name: " + + ML_CONNECTOR_INDEX, + exception.getMessage() + ); + } + + @Test + public void testTool_queryEmptyIndex() { + PPLTool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt")); + assertEquals(PPLTool.TYPE, tool.getName()); + Exception exception = assertThrows( + IllegalArgumentException.class, + () -> tool.run(ImmutableMap.of("question", "demo"), ActionListener.wrap(ppl -> { + assertEquals(pplResult, "ppl result"); + }, e -> { assertEquals("We cannot search system indices " + ML_CONNECTOR_INDEX, e.getMessage()); })) + ); + assertEquals( + "Return this final answer to human directly and do not use other tools: 'Please provide index name'. Please try to directly send this message to human to ask for index name", + exception.getMessage() + ); + } + + @Test + public void testTool_WrongModelType() { + PPLTool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "model_type", "wrong_model_type")); + assertEquals(PPLTool.PPLModelType.CLAUDE, tool.getPplModelType()); + } + + @Test + public void testTool_getMappingFailure() { + PPLTool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt")); + assertEquals(PPLTool.TYPE, tool.getName()); + Exception exception = new Exception("get mapping error"); + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onFailure(exception); + return null; + }).when(indicesAdminClient).getMappings(any(), any()); + + tool + .run( + ImmutableMap.of("index", "demo", "question", "demo"), + ActionListener.wrap(ppl -> { assertEquals(pplResult, "ppl result"); }, e -> { + assertEquals("get mapping error", e.getMessage()); + }) + ); + } + + @Test + public void testTool_predictModelFailure() { + PPLTool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt")); + assertEquals(PPLTool.TYPE, tool.getName()); + Exception exception = new Exception("predict model error"); + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(exception); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + + tool + .run( + ImmutableMap.of("index", "demo", "question", "demo"), + ActionListener.wrap(ppl -> { assertEquals(pplResult, "ppl result"); }, e -> { + assertEquals("predict model error", e.getMessage()); + }) + ); + } + + @Test + public void testTool_searchFailure() { + PPLTool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt")); + assertEquals(PPLTool.TYPE, tool.getName()); + Exception exception = new Exception("search error"); + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onFailure(exception); + return null; + }).when(client).search(any(), any()); + + tool + .run( + ImmutableMap.of("index", "demo", "question", "demo"), + ActionListener.wrap(ppl -> { assertEquals(pplResult, "ppl result"); }, e -> { + assertEquals("search error", e.getMessage()); + }) + ); + } + + @Test + public void testTool_executePPLFailure() { + PPLTool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt")); + assertEquals(PPLTool.TYPE, tool.getName()); + Exception exception = new Exception("execute ppl error"); + 
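// fail the PPL transport action so the tool surfaces the execution error to the listener +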
doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(exception); + return null; + }).when(client).execute(eq(PPLQueryAction.INSTANCE), any(), any()); + + tool + .run( + ImmutableMap.of("index", "demo", "question", "demo"), + ActionListener.wrap(ppl -> { assertEquals(pplResult, "ppl result"); }, e -> { + assertEquals("execute ppl:source=demo| head 1, get error: execute ppl error", e.getMessage()); + }) + ); + } + + private void createMappings() { + indexMappings = new HashMap<>(); + indexMappings + .put( + "properties", + ImmutableMap + .of( + "demoFields", + ImmutableMap.of("type", "text"), + "demoNested", + ImmutableMap + .of( + "properties", + ImmutableMap.of("nest1", ImmutableMap.of("type", "text"), "nest2", ImmutableMap.of("type", "text")) + ) + ) + ); + mockedMappings = new HashMap<>(); + mockedMappings.put(mockedIndexName, mappingMetadata); + + BytesReference bytesArray = new BytesArray("{\"demoFields\":\"111\", \"demoNested\": {\"nest1\": \"222\", \"nest2\": \"333\"}}"); + hit = new SearchHit(1); + hit.sourceRef(bytesArray); + searchHits = new SearchHits(new SearchHit[] { hit }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0f); + pplReturns = Collections.singletonMap("response", "source=demo| head 1"); + modelTensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, pplReturns); + + } + + private void initMLTensors(){ + when(modelTensors.getMlModelTensors()).thenReturn(Collections.singletonList(modelTensor)); + when(modelTensorOutput.getMlModelOutputs()).thenReturn(Collections.singletonList(modelTensors)); + when(mlTaskResponse.getOutput()).thenReturn(modelTensorOutput); + + // call model + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + } +} diff --git a/src/test/java/org/opensearch/agent/tools/RAGToolTests.java b/src/test/java/org/opensearch/agent/tools/RAGToolTests.java new file mode 100644 index 00000000..0f19f91a --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/RAGToolTests.java @@ -0,0 +1,509 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.opensearch.agent.tools.AbstractRetrieverTool.*; +import static org.opensearch.agent.tools.AbstractRetrieverToolTests.*; +import static org.opensearch.agent.tools.VectorDBTool.DEFAULT_K; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.*; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.ParseField; +import 
org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; + +import lombok.SneakyThrows; + +public class RAGToolTests { + public static final String TEST_QUERY_TEXT = "hello?"; + public static final String TEST_EMBEDDING_FIELD = "test_embedding"; + public static final String TEST_EMBEDDING_MODEL_ID = "1234"; + public static final String TEST_INFERENCE_MODEL_ID = "1234"; + public static final String TEST_NEURAL_QUERY_TYPE = "neural"; + public static final String TEST_NEURAL_SPARSE_QUERY_TYPE = "neural_sparse"; + public static final String TEST_NESTED_PATH = "nested_path"; + + static public final NamedXContentRegistry TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY = getQueryNamedXContentRegistry(); + private RAGTool ragTool; + private String mockedSearchResponseString; + private String mockedEmptySearchResponseString; + private String mockedNeuralSparseSearchResponseString; + @Mock + private Parser mockOutputParser; + @Mock + private Client client; + @Mock + private ActionListener listener; + private Map params; + + @Before + @SneakyThrows + public void setup() { + try (InputStream searchResponseIns = AbstractRetrieverTool.class.getResourceAsStream("retrieval_tool_search_response.json")) { + if (searchResponseIns != null) { + mockedSearchResponseString = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8); + } + } + try (InputStream searchResponseIns = AbstractRetrieverTool.class.getResourceAsStream("retrieval_tool_empty_search_response.json")) { + if (searchResponseIns != null) { + mockedEmptySearchResponseString = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8); + } + } + + try (InputStream searchResponseIns = AbstractRetrieverTool.class.getResourceAsStream("neural_sparse_tool_search_response.json")) { + if (searchResponseIns != null) { + mockedNeuralSparseSearchResponseString = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8); + } + } + client = mock(Client.class); + listener = mock(ActionListener.class); + RAGTool.Factory.getInstance().init(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY); + VectorDBTool.Factory.getInstance().init(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY); + NeuralSparseSearchTool.Factory.getInstance().init(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY); + params = new HashMap<>(); + params.put(RAGTool.INDEX_FIELD, TEST_INDEX); + params.put(RAGTool.EMBEDDING_FIELD, TEST_EMBEDDING_FIELD); + params.put(RAGTool.SOURCE_FIELD, gson.toJson(TEST_SOURCE_FIELDS)); + params.put(RAGTool.EMBEDDING_MODEL_ID_FIELD, TEST_EMBEDDING_MODEL_ID); + params.put(RAGTool.INFERENCE_MODEL_ID_FIELD, TEST_INFERENCE_MODEL_ID); + params.put(RAGTool.DOC_SIZE_FIELD, AbstractRetrieverToolTests.TEST_DOC_SIZE.toString()); + params.put(RAGTool.K_FIELD, DEFAULT_K.toString()); + params.put(RAGTool.QUERY_TYPE, TEST_NEURAL_QUERY_TYPE); + params.put(RAGTool.CONTENT_GENERATION_FIELD, "true"); + ragTool = 
RAGTool.Factory.getInstance().create(params); + } + + @Test + public void testValidate() { + assertTrue(ragTool.validate(Map.of(AbstractRetrieverTool.INPUT_FIELD, "hi"))); + assertFalse(ragTool.validate(Map.of(AbstractRetrieverTool.INPUT_FIELD, ""))); + assertFalse(ragTool.validate(Map.of(AbstractRetrieverTool.INPUT_FIELD, " "))); + assertFalse(ragTool.validate(Map.of("test", " "))); + assertFalse(ragTool.validate(new HashMap<>())); + assertFalse(ragTool.validate(null)); + } + + @Test + public void testGetAttributes() { + assertEquals(ragTool.getVersion(), null); + assertEquals(ragTool.getType(), RAGTool.TYPE); + assertEquals(ragTool.getInferenceModelId(), TEST_INFERENCE_MODEL_ID); + } + + @Test + public void testSetName() { + assertEquals(ragTool.getName(), RAGTool.TYPE); + ragTool.setName("test-tool"); + assertEquals(ragTool.getName(), "test-tool"); + } + + @Test + public void testOutputParser() throws IOException { + + NamedXContentRegistry mockNamedXContentRegistry = getQueryNamedXContentRegistry(); + ragTool.setXContentRegistry(mockNamedXContentRegistry); + + ModelTensorOutput mlModelTensorOutput = getMlModelTensorOutput(); + SearchResponse mockedSearchResponse = SearchResponse + .fromXContent( + JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, mockedSearchResponseString) + ); + + doAnswer(invocation -> { + SearchRequest searchRequest = invocation.getArgument(0); + assertEquals((long) TEST_DOC_SIZE, (long) searchRequest.source().size()); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mockedSearchResponse); + return null; + }).when(client).search(any(), any()); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + + ragTool.setOutputParser(mockOutputParser); + ragTool.run(Map.of(INPUT_FIELD, TEST_QUERY_TEXT), listener); + + verify(client).search(any(), any()); + verify(client).execute(any(), any(), any()); + } + + @Test + public void testRunWithEmptySearchResponse() throws IOException { + NamedXContentRegistry mockNamedXContentRegistry = getQueryNamedXContentRegistry(); + ragTool.setXContentRegistry(mockNamedXContentRegistry); + + ModelTensorOutput mlModelTensorOutput = getMlModelTensorOutput(); + SearchResponse mockedEmptySearchResponse = SearchResponse + .fromXContent( + JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, mockedEmptySearchResponseString) + ); + + doAnswer(invocation -> { + SearchRequest searchRequest = invocation.getArgument(0); + assertEquals((long) TEST_DOC_SIZE, (long) searchRequest.source().size()); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mockedEmptySearchResponse); + return null; + }).when(client).search(any(), any()); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + ragTool.run(Map.of(INPUT_FIELD, TEST_QUERY_TEXT), listener); + verify(client).search(any(), any()); + verify(client).execute(any(), any(), any()); + } + + @Test + public void testRunWithNeuralSparseQueryType() throws IOException { + + Map paramsWithNeuralSparse = new 
HashMap<>(params); + paramsWithNeuralSparse.put(RAGTool.QUERY_TYPE, TEST_NEURAL_SPARSE_QUERY_TYPE); + + RAGTool rAGtoolWithNeuralSparseQuery = RAGTool.Factory.getInstance().create(paramsWithNeuralSparse); + + NamedXContentRegistry mockNamedXContentRegistry = getQueryNamedXContentRegistry(); + rAGtoolWithNeuralSparseQuery.setXContentRegistry(mockNamedXContentRegistry); + + ModelTensorOutput mlModelTensorOutput = getMlModelTensorOutput(); + SearchResponse mockedNeuralSparseSearchResponse = SearchResponse + .fromXContent( + JsonXContent.jsonXContent + .createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + mockedNeuralSparseSearchResponseString + ) + ); + + doAnswer(invocation -> { + SearchRequest searchRequest = invocation.getArgument(0); + assertEquals((long) TEST_DOC_SIZE, (long) searchRequest.source().size()); + ActionListener<SearchResponse> listener = invocation.getArgument(1); + listener.onResponse(mockedNeuralSparseSearchResponse); + return null; + }).when(client).search(any(), any()); + + doAnswer(invocation -> { + ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + rAGtoolWithNeuralSparseQuery.run(Map.of(INPUT_FIELD, TEST_QUERY_TEXT), listener); + verify(client).search(any(), any()); + verify(client).execute(any(), any(), any()); + } + + @Test + public void testRunWithInvalidQueryType() throws IOException { + RAGTool.Factory.getInstance().init(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY); + Map<String, String> paramsWithInvalidQueryType = new HashMap<>(params); + paramsWithInvalidQueryType.put(RAGTool.QUERY_TYPE, "sparse"); + try { + RAGTool.Factory.getInstance().create(paramsWithInvalidQueryType); + throw new AssertionError("expected an IllegalArgumentException for the invalid query_type"); + } catch (IllegalArgumentException e) { + assertEquals("Failed to read queryType, please input neural_sparse or neural.", e.getMessage()); + } + } + + @Test + public void testRunWithQuestionJson() throws IOException { + NamedXContentRegistry mockNamedXContentRegistry = getQueryNamedXContentRegistry(); + ragTool.setXContentRegistry(mockNamedXContentRegistry); + + ModelTensorOutput mlModelTensorOutput = getMlModelTensorOutput(); + SearchResponse mockedEmptySearchResponse = SearchResponse + .fromXContent( + JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, mockedEmptySearchResponseString) + ); + + doAnswer(invocation -> { + SearchRequest searchRequest = invocation.getArgument(0); + assertEquals((long) TEST_DOC_SIZE, (long) searchRequest.source().size()); + ActionListener<SearchResponse> listener = invocation.getArgument(1); + listener.onResponse(mockedEmptySearchResponse); + return null; + }).when(client).search(any(), any()); + + doAnswer(invocation -> { + ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + ragTool.run(Map.of(INPUT_FIELD, "{question:'what is the population in seattle?'}"), listener); + verify(client).search(any(), any()); + verify(client).execute(any(), any(), any()); + } + + @Test + public void testRunEmptyResponseWithNotEnableContentGeneration() throws IOException { + ActionListener<String> mockListener = mock(ActionListener.class); + Map<String, String> paramsWithNotEnableContentGeneration = new HashMap<>(params); + 
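// with content generation disabled the tool should return the raw search result instead of calling the inference model +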
paramsWithNotEnableContentGeneration.put(RAGTool.CONTENT_GENERATION_FIELD, "false"); + + RAGTool rAGtoolWithNotEnableContentGeneration = RAGTool.Factory.getInstance().create(paramsWithNotEnableContentGeneration); + + NamedXContentRegistry mockNamedXContentRegistry = getQueryNamedXContentRegistry(); + rAGtoolWithNotEnableContentGeneration.setXContentRegistry(mockNamedXContentRegistry); + + SearchResponse mockedEmptySearchResponse = SearchResponse + .fromXContent( + JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, mockedEmptySearchResponseString) + ); + + doAnswer(invocation -> { + SearchRequest searchRequest = invocation.getArgument(0); + assertEquals((long) TEST_DOC_SIZE, (long) searchRequest.source().size()); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mockedEmptySearchResponse); + return null; + }).when(client).search(any(), any()); + rAGtoolWithNotEnableContentGeneration.run(Map.of(INPUT_FIELD, "{question:'what is the population in seattle?'}"), mockListener); + + verify(client).search(any(), any()); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(mockListener).onResponse(responseCaptor.capture()); + assertEquals("Can not get any match from search result.", responseCaptor.getValue()); + + } + + @Test + public void testRunResponseWithNotEnableContentGeneration() throws IOException { + ActionListener mockListener = mock(ActionListener.class); + Map paramsWithNotEnableContentGeneration = new HashMap<>(params); + paramsWithNotEnableContentGeneration.put(RAGTool.CONTENT_GENERATION_FIELD, "false"); + + RAGTool rAGtoolWithNotEnableContentGeneration = RAGTool.Factory.getInstance().create(paramsWithNotEnableContentGeneration); + + NamedXContentRegistry mockNamedXContentRegistry = getQueryNamedXContentRegistry(); + rAGtoolWithNotEnableContentGeneration.setXContentRegistry(mockNamedXContentRegistry); + + SearchResponse mockedNeuralSparseSearchResponse = SearchResponse + .fromXContent( + JsonXContent.jsonXContent + .createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.IGNORE_DEPRECATIONS, + mockedNeuralSparseSearchResponseString + ) + ); + + doAnswer(invocation -> { + SearchRequest searchRequest = invocation.getArgument(0); + assertEquals((long) TEST_DOC_SIZE, (long) searchRequest.source().size()); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mockedNeuralSparseSearchResponse); + return null; + }).when(client).search(any(), any()); + rAGtoolWithNotEnableContentGeneration.run(Map.of(INPUT_FIELD, "{question:'what is the population in seattle?'}"), mockListener); + + verify(client).search(any(), any()); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(mockListener).onResponse(responseCaptor.capture()); + assertEquals( + "{\"_index\":\"my-nlp-index\",\"_source\":{\"passage_text\":\"Hello world\",\"passage_embedding\":{\"!\":0.8708904,\"door\":0.8587369,\"hi\":2.3929274,\"worlds\":2.7839446,\"yes\":0.75845814,\"##world\":2.5432441,\"born\":0.2682308,\"nothing\":0.8625516,\"goodbye\":0.17146169,\"greeting\":0.96817183,\"birth\":1.2788506,\"come\":0.1623208,\"global\":0.4371151,\"it\":0.42951578,\"life\":1.5750692,\"thanks\":0.26481047,\"world\":4.7300377,\"tiny\":0.5462298,\"earth\":2.6555297,\"universe\":2.0308156,\"worldwide\":1.3903781,\"hello\":6.696973,\"so\":0.20279501,\"?\":0.67785245},\"id\":\"s1\"},\"_id\":\"1\",\"_score\":30.0029}\n" + + 
"{\"_index\":\"my-nlp-index\",\"_source\":{\"passage_text\":\"Hi planet\",\"passage_embedding\":{\"hi\":4.338913,\"planets\":2.7755864,\"planet\":5.0969057,\"mars\":1.7405145,\"earth\":2.6087382,\"hello\":3.3210192},\"id\":\"s2\"},\"_id\":\"2\",\"_score\":16.480486}\n", + responseCaptor.getValue() + ); + + } + + @Test + @SneakyThrows + public void testRunWithRuntimeExceptionDuringSearch() { + NamedXContentRegistry mockNamedXContentRegistry = getQueryNamedXContentRegistry(); + ragTool.setXContentRegistry(mockNamedXContentRegistry); + doAnswer(invocation -> { + SearchRequest searchRequest = invocation.getArgument(0); + assertEquals((long) TEST_DOC_SIZE, (long) searchRequest.source().size()); + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new RuntimeException("Failed to search index")); + return null; + }).when(client).search(any(), any()); + ragTool.run(Map.of(INPUT_FIELD, TEST_QUERY_TEXT), listener); + verify(listener).onFailure(any(RuntimeException.class)); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to search index", argumentCaptor.getValue().getMessage()); + } + + @Test + @SneakyThrows + public void testRunWithRuntimeExceptionDuringExecute() { + NamedXContentRegistry mockNamedXContentRegistry = getQueryNamedXContentRegistry(); + ragTool.setXContentRegistry(mockNamedXContentRegistry); + + SearchResponse mockedSearchResponse = SearchResponse + .fromXContent( + JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, mockedSearchResponseString) + ); + + doAnswer(invocation -> { + SearchRequest searchRequest = invocation.getArgument(0); + assertEquals((long) TEST_DOC_SIZE, (long) searchRequest.source().size()); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mockedSearchResponse); + return null; + }).when(client).search(any(), any()); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new RuntimeException("Failed to run model " + TEST_INFERENCE_MODEL_ID)); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + + ragTool.run(Map.of(INPUT_FIELD, TEST_QUERY_TEXT), listener); + verify(listener).onFailure(any(RuntimeException.class)); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to run model " + TEST_INFERENCE_MODEL_ID, argumentCaptor.getValue().getMessage()); + } + + @Test(expected = IllegalArgumentException.class) + public void testRunWithEmptyInput() { + ActionListener listener = mock(ActionListener.class); + ragTool.run(Map.of(INPUT_FIELD, ""), listener); + } + + @Test + public void testFactoryNeuralQuery() { + RAGTool.Factory factoryMock = new RAGTool.Factory(); + RAGTool.Factory.getInstance().init(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY); + factoryMock.init(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY); + + String defaultDescription = factoryMock.getDefaultDescription(); + assertEquals(RAGTool.DEFAULT_DESCRIPTION, defaultDescription); + assertEquals(factoryMock.getDefaultType(), RAGTool.TYPE); + assertEquals(factoryMock.getDefaultVersion(), null); + assertNotNull(RAGTool.Factory.getInstance()); + + params.put(VectorDBTool.NESTED_PATH_FIELD, TEST_NESTED_PATH); + RAGTool rAGtool1 = factoryMock.create(params); + 
VectorDBTool.Factory.getInstance().init(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY); + params.put(VectorDBTool.MODEL_ID_FIELD, TEST_EMBEDDING_MODEL_ID); + VectorDBTool queryTool = VectorDBTool.Factory.getInstance().create(params); + RAGTool rAGtool2 = new RAGTool(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY, TEST_INFERENCE_MODEL_ID, true, queryTool); + + assertEquals(rAGtool1.getClient(), rAGtool2.getClient()); + assertEquals(rAGtool1.getInferenceModelId(), rAGtool2.getInferenceModelId()); + assertEquals(rAGtool1.getName(), rAGtool2.getName()); + assertEquals(rAGtool1.getQueryTool().getDocSize(), rAGtool2.getQueryTool().getDocSize()); + assertEquals(rAGtool1.getQueryTool().getIndex(), rAGtool2.getQueryTool().getIndex()); + assertEquals(rAGtool1.getQueryTool().getSourceFields(), rAGtool2.getQueryTool().getSourceFields()); + assertEquals(rAGtool1.getXContentRegistry(), rAGtool2.getXContentRegistry()); + assertEquals(rAGtool1.getQueryType(), rAGtool2.getQueryType()); + assertEquals(((VectorDBTool) rAGtool1.getQueryTool()).getNestedPath(), ((VectorDBTool) rAGtool2.getQueryTool()).getNestedPath()); + } + + @Test + public void testFactoryNeuralSparseQuery() { + RAGTool.Factory factoryMock = new RAGTool.Factory(); + RAGTool.Factory.getInstance().init(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY); + factoryMock.init(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY); + + String defaultDescription = factoryMock.getDefaultDescription(); + assertEquals(RAGTool.DEFAULT_DESCRIPTION, defaultDescription); + assertNotNull(RAGTool.Factory.getInstance()); + assertEquals(factoryMock.getDefaultType(), RAGTool.TYPE); + assertEquals(factoryMock.getDefaultVersion(), null); + + params.put(NeuralSparseSearchTool.NESTED_PATH_FIELD, TEST_NESTED_PATH); + params.put("query_type", "neural_sparse"); + RAGTool rAGtool1 = factoryMock.create(params); + NeuralSparseSearchTool.Factory.getInstance().init(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY); + NeuralSparseSearchTool queryTool = NeuralSparseSearchTool.Factory.getInstance().create(params); + RAGTool rAGtool2 = new RAGTool(client, TEST_XCONTENT_REGISTRY_FOR_NEURAL_QUERY, TEST_INFERENCE_MODEL_ID, true, queryTool); + + assertEquals(rAGtool1.getClient(), rAGtool2.getClient()); + assertEquals(rAGtool1.getInferenceModelId(), rAGtool2.getInferenceModelId()); + assertEquals(rAGtool1.getName(), rAGtool2.getName()); + assertEquals(rAGtool1.getQueryTool().getDocSize(), rAGtool2.getQueryTool().getDocSize()); + assertEquals(rAGtool1.getQueryTool().getIndex(), rAGtool2.getQueryTool().getIndex()); + assertEquals(rAGtool1.getQueryTool().getSourceFields(), rAGtool2.getQueryTool().getSourceFields()); + assertEquals(rAGtool1.getXContentRegistry(), rAGtool2.getXContentRegistry()); + assertEquals(rAGtool1.getQueryType(), rAGtool2.getQueryType()); + assertEquals( + ((NeuralSparseSearchTool) rAGtool1.getQueryTool()).getNestedPath(), + ((NeuralSparseSearchTool) rAGtool2.getQueryTool()).getNestedPath() + ); + } + + private static NamedXContentRegistry getQueryNamedXContentRegistry() { + QueryBuilder matchAllQueryBuilder = new MatchAllQueryBuilder(); + + List entries = new ArrayList<>(); + NamedXContentRegistry.Entry neural_query_entry = new NamedXContentRegistry.Entry( + QueryBuilder.class, + new ParseField("neural"), + (p, c) -> { + p.map(); + return matchAllQueryBuilder; + } + ); + entries.add(neural_query_entry); + NamedXContentRegistry.Entry neural_sparse_query_entry = new NamedXContentRegistry.Entry( + QueryBuilder.class, + new ParseField("neural_sparse"), + (p, c) -> { + 
p.map(); + return matchAllQueryBuilder; + } + ); + entries.add(neural_sparse_query_entry); + NamedXContentRegistry mockNamedXContentRegistry = new NamedXContentRegistry(entries); + return mockNamedXContentRegistry; + } + + private static ModelTensorOutput getMlModelTensorOutput() { + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("thought", "thought 1", "action", "action1")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + return mlModelTensorOutput; + } +} diff --git a/src/test/java/org/opensearch/agent/tools/SearchAlertsToolTests.java b/src/test/java/org/opensearch/agent/tools/SearchAlertsToolTests.java new file mode 100644 index 00000000..ac1a8b3b --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/SearchAlertsToolTests.java @@ -0,0 +1,200 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.time.Instant; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionType; +import org.opensearch.client.AdminClient; +import org.opensearch.client.ClusterAdminClient; +import org.opensearch.client.IndicesAdminClient; +import org.opensearch.client.node.NodeClient; +import org.opensearch.commons.alerting.action.GetAlertsResponse; +import org.opensearch.commons.alerting.model.Alert; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.spi.tools.Tool; + +public class SearchAlertsToolTests { + @Mock + private NodeClient nodeClient; + @Mock + private AdminClient adminClient; + @Mock + private IndicesAdminClient indicesAdminClient; + @Mock + private ClusterAdminClient clusterAdminClient; + + private Map nullParams; + private Map emptyParams; + private Map nonEmptyParams; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + SearchAlertsTool.Factory.getInstance().init(nodeClient); + + nullParams = null; + emptyParams = Collections.emptyMap(); + nonEmptyParams = Map.of("searchString", "foo"); + } + + @Test + public void testRunWithNoAlerts() throws Exception { + Tool tool = SearchAlertsTool.Factory.getInstance().create(Collections.emptyMap()); + GetAlertsResponse getAlertsResponse = new GetAlertsResponse(Collections.emptyList(), 0); + String expectedResponseStr = "Alerts=[]TotalAlerts=0"; + + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + + doAnswer((invocation) -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(getAlertsResponse); + return null; + }).when(nodeClient).execute(any(ActionType.class), any(), any()); + + tool.run(nonEmptyParams, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(listener, 
times(1)).onResponse(responseCaptor.capture()); + assertEquals(expectedResponseStr, responseCaptor.getValue()); + } + + @Test + public void testRunWithAlerts() throws Exception { + Tool tool = SearchAlertsTool.Factory.getInstance().create(Collections.emptyMap()); + Alert alert1 = new Alert( + "alert-id-1", + 1234, + 1, + "monitor-id", + "workflow-id", + "workflow-name", + "monitor-name", + 1234, + null, + "trigger-id", + "trigger-name", + Collections.emptyList(), + Collections.emptyList(), + Alert.State.ACKNOWLEDGED, + Instant.now(), + null, + null, + null, + null, + Collections.emptyList(), + "test-severity", + Collections.emptyList(), + null, + null, + Collections.emptyList(), + null + ); + Alert alert2 = new Alert( + "alert-id-2", + 1234, + 1, + "monitor-id", + "workflow-id", + "workflow-name", + "monitor-name", + 1234, + null, + "trigger-id", + "trigger-name", + Collections.emptyList(), + Collections.emptyList(), + Alert.State.ACKNOWLEDGED, + Instant.now(), + null, + null, + null, + null, + Collections.emptyList(), + "test-severity", + Collections.emptyList(), + null, + null, + Collections.emptyList(), + null + ); + List mockAlerts = List.of(alert1, alert2); + + GetAlertsResponse getAlertsResponse = new GetAlertsResponse(mockAlerts, mockAlerts.size()); + String expectedResponseStr = new StringBuilder() + .append("Alerts=[") + .append(alert1.toString()) + .append(alert2.toString()) + .append("]TotalAlerts=2") + .toString(); + + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + + doAnswer((invocation) -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(getAlertsResponse); + return null; + }).when(nodeClient).execute(any(ActionType.class), any(), any()); + + tool.run(nonEmptyParams, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals(expectedResponseStr, responseCaptor.getValue()); + } + + @Test + public void testParseParams() throws Exception { + Tool tool = SearchAlertsTool.Factory.getInstance().create(Collections.emptyMap()); + Map validParams = new HashMap(); + validParams.put("sortOrder", "asc"); + validParams.put("sortString", "foo.bar"); + validParams.put("size", "10"); + validParams.put("startIndex", "0"); + validParams.put("searchString", "foo"); + validParams.put("severityLevel", "ALL"); + validParams.put("alertState", "ALL"); + validParams.put("monitorId", "foo"); + validParams.put("alertIndex", "foo"); + + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + assertDoesNotThrow(() -> tool.run(validParams, listener)); + assertDoesNotThrow(() -> tool.run(Map.of("monitorIds", "[]"), listener)); + assertDoesNotThrow(() -> tool.run(Map.of("monitorIds", "[foo]"), listener)); + assertDoesNotThrow(() -> tool.run(Map.of("workflowIds", "[]"), listener)); + assertDoesNotThrow(() -> tool.run(Map.of("workflowIds", "[foo]"), listener)); + assertDoesNotThrow(() -> tool.run(Map.of("alertIds", "[]"), listener)); + assertDoesNotThrow(() -> tool.run(Map.of("alertIds", "[foo]"), listener)); + } + + @Test + public void testValidate() { + Tool tool = SearchAlertsTool.Factory.getInstance().create(Collections.emptyMap()); + assertEquals(SearchAlertsTool.TYPE, tool.getType()); + assertTrue(tool.validate(emptyParams)); + assertTrue(tool.validate(nonEmptyParams)); + assertTrue(tool.validate(nullParams)); + } +} diff --git 
a/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java b/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java new file mode 100644 index 00000000..e0b04336 --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/SearchAnomalyDetectorsToolTests.java @@ -0,0 +1,481 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionType; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.transport.GetAnomalyDetectorAction; +import org.opensearch.ad.transport.GetAnomalyDetectorResponse; +import org.opensearch.ad.transport.SearchAnomalyDetectorAction; +import org.opensearch.agent.TestHelpers; +import org.opensearch.agent.tools.utils.ToolConstants.DetectorStateString; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.search.SearchHit; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; + +public class SearchAnomalyDetectorsToolTests { + @Mock + private NamedWriteableRegistry namedWriteableRegistry; + @Mock + private NodeClient nodeClient; + + private Map nullParams; + private Map emptyParams; + private Map nonEmptyParams; + + private AnomalyDetector testDetector; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + SearchAnomalyDetectorsTool.Factory.getInstance().init(nodeClient, namedWriteableRegistry); + + nullParams = null; + emptyParams = Collections.emptyMap(); + nonEmptyParams = Map.of("detectorName", "foo"); + + testDetector = new AnomalyDetector( + "foo-id", + 1L, + "foo-name", + "foo-description", + "foo-time-field", + new ArrayList(Arrays.asList("foo-index")), + Collections.emptyList(), + null, + new IntervalTimeConfiguration(5, ChronoUnit.MINUTES), + null, + 1, + Collections.emptyMap(), + 1, + Instant.now(), + Collections.emptyList(), + null, + null, + null, + null, + null, + null, + null, + null, + null, + null + ); + } + + @Test + public void testRunWithNoDetectors() throws Exception { + Tool tool = SearchAnomalyDetectorsTool.Factory.getInstance().create(Collections.emptyMap()); + SearchResponse getDetectorsResponse = TestHelpers.generateSearchResponse(new SearchHit[0]); + String expectedResponseStr = String.format(Locale.getDefault(), "AnomalyDetectors=[]TotalAnomalyDetectors=0"); + + 
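+ // The transport call is stubbed with Mockito's doAnswer below: the stub pulls the + // ActionListener passed as the third argument to NodeClient#execute and completes + // it with the canned SearchResponse, so no running cluster is required.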
@SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + + doAnswer((invocation) -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(getDetectorsResponse); + return null; + }).when(nodeClient).execute(any(ActionType.class), any(), any()); + + tool.run(emptyParams, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals(expectedResponseStr, responseCaptor.getValue()); + } + + @Test + public void testRunWithSingleAnomalyDetector() throws Exception { + final String detectorName = "detector-1"; + final String detectorId = "detector-1-id"; + Tool tool = SearchAnomalyDetectorsTool.Factory.getInstance().create(Collections.emptyMap()); + + XContentBuilder content = XContentBuilder.builder(XContentType.JSON.xContent()); + content.startObject(); + content.field("name", testDetector.getName()); + content.field("detector_type", testDetector.getDetectorType()); + content.field("description", testDetector.getDescription()); + content.field("indices", testDetector.getIndices().get(0)); + content.field("last_update_time", testDetector.getLastUpdateTime().toEpochMilli()); + content.endObject(); + SearchHit[] hits = new SearchHit[1]; + hits[0] = new SearchHit(0, testDetector.getId(), null, null).sourceRef(BytesReference.bytes(content)); + SearchResponse getDetectorsResponse = TestHelpers.generateSearchResponse(hits); + String expectedResponseStr = getExpectedResponseString(testDetector); + + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + + doAnswer((invocation) -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(getDetectorsResponse); + return null; + }).when(nodeClient).execute(any(ActionType.class), any(), any()); + + tool.run(emptyParams, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals(expectedResponseStr, responseCaptor.getValue()); + } + + @Test + public void testRunWithRunningDetectorTrue() throws Exception { + final String detectorName = "detector-1"; + final String detectorId = "detector-1-id"; + Tool tool = SearchAnomalyDetectorsTool.Factory.getInstance().create(Collections.emptyMap()); + + // Generate mock values and responses + SearchHit[] hits = new SearchHit[1]; + hits[0] = TestHelpers.generateSearchDetectorHit(detectorName, detectorId); + SearchResponse getDetectorsResponse = TestHelpers.generateSearchResponse(hits); + GetAnomalyDetectorResponse getDetectorProfileResponse = TestHelpers + .generateGetAnomalyDetectorResponses(new String[] { detectorName }, new String[] { DetectorStateString.Running.name() }); + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + mockProfileApiCalls(getDetectorsResponse, getDetectorProfileResponse); + + tool.run(Map.of("running", "true"), listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + String response = responseCaptor.getValue(); + assertTrue(response.contains(String.format(Locale.ROOT, "id=%s", detectorId))); + assertTrue(response.contains(String.format(Locale.ROOT, "name=%s", detectorName))); + assertTrue(response.contains(String.format(Locale.ROOT, "TotalAnomalyDetectors=%d", hits.length))); + } + + @Test + 
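+ // With running=false, a detector whose mocked realtime task reports the Running state should be filtered out.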
public void testRunWithRunningDetectorFalse() throws Exception { + final String detectorName = "detector-1"; + final String detectorId = "detector-1-id"; + Tool tool = SearchAnomalyDetectorsTool.Factory.getInstance().create(Collections.emptyMap()); + + // Generate mock values and responses + SearchHit[] hits = new SearchHit[1]; + hits[0] = TestHelpers.generateSearchDetectorHit(detectorName, detectorId); + SearchResponse getDetectorsResponse = TestHelpers.generateSearchResponse(hits); + GetAnomalyDetectorResponse getDetectorProfileResponse = TestHelpers + .generateGetAnomalyDetectorResponses(new String[] { detectorName }, new String[] { DetectorStateString.Running.name() }); + String expectedResponseStr = "AnomalyDetectors=[]TotalAnomalyDetectors=0"; + @SuppressWarnings("unchecked") + ActionListener<String> listener = Mockito.mock(ActionListener.class); + mockProfileApiCalls(getDetectorsResponse, getDetectorProfileResponse); + + tool.run(Map.of("running", "false"), listener); + ArgumentCaptor<String> responseCaptor = ArgumentCaptor.forClass(String.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals(expectedResponseStr, responseCaptor.getValue()); + } + + @Test + public void testRunWithRunningDetectorUndefined() throws Exception { + final String detectorName = "detector-1"; + final String detectorId = "detector-1-id"; + Tool tool = SearchAnomalyDetectorsTool.Factory.getInstance().create(Collections.emptyMap()); + + // Generate mock values and responses + SearchHit[] hits = new SearchHit[1]; + hits[0] = TestHelpers.generateSearchDetectorHit(detectorName, detectorId); + SearchResponse getDetectorsResponse = TestHelpers.generateSearchResponse(hits); + GetAnomalyDetectorResponse getDetectorProfileResponse = TestHelpers + .generateGetAnomalyDetectorResponses(new String[] { detectorName }, new String[] { DetectorStateString.Running.name() }); + @SuppressWarnings("unchecked") + ActionListener<String> listener = Mockito.mock(ActionListener.class); + mockProfileApiCalls(getDetectorsResponse, getDetectorProfileResponse); + + tool.run(Map.of("foo", "bar"), listener); + ArgumentCaptor<String> responseCaptor = ArgumentCaptor.forClass(String.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + String response = responseCaptor.getValue(); + assertTrue(response.contains(String.format(Locale.ROOT, "id=%s", detectorId))); + assertTrue(response.contains(String.format(Locale.ROOT, "name=%s", detectorName))); + assertTrue(response.contains(String.format(Locale.ROOT, "TotalAnomalyDetectors=%d", hits.length))); + } + + @Test + public void testRunWithNullRealtimeTask() throws Exception { + final String detectorName = "detector-1"; + final String detectorId = "detector-1-id"; + Tool tool = SearchAnomalyDetectorsTool.Factory.getInstance().create(Collections.emptyMap()); + + // Generate mock values and responses + SearchHit[] hits = new SearchHit[1]; + hits[0] = TestHelpers.generateSearchDetectorHit(detectorName, detectorId); + SearchResponse getDetectorsResponse = TestHelpers.generateSearchResponse(hits); + GetAnomalyDetectorResponse getDetectorProfileResponse = TestHelpers + .generateGetAnomalyDetectorResponses(new String[] { detectorName }, new String[] { DetectorStateString.Running.name() }); + // Overriding the mocked realtime task response to return null. This occurs when + // a detector is created but is never started. 
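+ // A null realtime task is treated as not running, so the detector should still be returned for running=false.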
+ when(getDetectorProfileResponse.getRealtimeAdTask()).thenReturn(null); + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + mockProfileApiCalls(getDetectorsResponse, getDetectorProfileResponse); + + tool.run(Map.of("running", "false"), listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + String response = responseCaptor.getValue(); + assertTrue(response.contains(String.format(Locale.ROOT, "id=%s", detectorId))); + assertTrue(response.contains(String.format(Locale.ROOT, "name=%s", detectorName))); + assertTrue(response.contains(String.format(Locale.ROOT, "TotalAnomalyDetectors=%d", hits.length))); + } + + @Test + public void testRunWithTaskStateCreated() throws Exception { + final String detectorName = "detector-1"; + final String detectorId = "detector-1-id"; + Tool tool = SearchAnomalyDetectorsTool.Factory.getInstance().create(Collections.emptyMap()); + + // Generate mock values and responses + SearchHit[] hits = new SearchHit[1]; + hits[0] = TestHelpers.generateSearchDetectorHit(detectorName, detectorId); + SearchResponse getDetectorsResponse = TestHelpers.generateSearchResponse(hits); + GetAnomalyDetectorResponse getDetectorProfileResponse = TestHelpers + .generateGetAnomalyDetectorResponses(new String[] { detectorName }, new String[] { DetectorStateString.Running.name() }); + // Overriding the mocked response to set realtime task state to CREATED + when(getDetectorProfileResponse.getRealtimeAdTask().getState()).thenReturn("CREATED"); + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + mockProfileApiCalls(getDetectorsResponse, getDetectorProfileResponse); + + tool.run(Map.of("running", "true"), listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + String response = responseCaptor.getValue(); + assertTrue(response.contains(String.format(Locale.ROOT, "id=%s", detectorId))); + assertTrue(response.contains(String.format(Locale.ROOT, "name=%s", detectorName))); + assertTrue(response.contains(String.format(Locale.ROOT, "TotalAnomalyDetectors=%d", hits.length))); + } + + @Test + public void testRunWithTaskStateVariousFailed() throws Exception { + final String detectorName1 = "detector-1"; + final String detectorId1 = "detector-1-id"; + final String detectorName2 = "detector-2"; + final String detectorId2 = "detector-2-id"; + final String detectorName3 = "detector-3"; + final String detectorId3 = "detector-3-id"; + Tool tool = SearchAnomalyDetectorsTool.Factory.getInstance().create(Collections.emptyMap()); + + // Generate mock values and responses + SearchHit[] hits = new SearchHit[3]; + hits[0] = TestHelpers.generateSearchDetectorHit(detectorName1, detectorId1); + hits[1] = TestHelpers.generateSearchDetectorHit(detectorName2, detectorId2); + hits[2] = TestHelpers.generateSearchDetectorHit(detectorName3, detectorId3); + SearchResponse getDetectorsResponse = TestHelpers.generateSearchResponse(hits); + GetAnomalyDetectorResponse getDetectorProfileResponse = TestHelpers + .generateGetAnomalyDetectorResponses( + new String[] { detectorName1, detectorName2, detectorName3 }, + new String[] { "INIT_FAILURE", "UNEXPECTED_FAILURE", "FAILED" } + ); + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + mockProfileApiCalls(getDetectorsResponse, getDetectorProfileResponse); 
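+ // A realtime task left in CREATED state counts as running here, so the detector should match running=true.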
+ + tool.run(Map.of("failed", "true"), listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + String response = responseCaptor.getValue(); + assertTrue(response.contains(String.format(Locale.ROOT, "id=%s", detectorId1))); + assertTrue(response.contains(String.format(Locale.ROOT, "name=%s", detectorName1))); + assertTrue(response.contains(String.format(Locale.ROOT, "id=%s", detectorId2))); + assertTrue(response.contains(String.format(Locale.ROOT, "name=%s", detectorName2))); + assertTrue(response.contains(String.format(Locale.ROOT, "id=%s", detectorId3))); + assertTrue(response.contains(String.format(Locale.ROOT, "name=%s", detectorName3))); + assertTrue(response.contains(String.format(Locale.ROOT, "TotalAnomalyDetectors=%d", hits.length))); + } + + @Test + public void testRunWithCombinedDetectorStatesTrue() throws Exception { + final String detectorName1 = "detector-1"; + final String detectorId1 = "detector-1-id"; + final String detectorName2 = "detector-2"; + final String detectorId2 = "detector-2-id"; + final String detectorName3 = "detector-3"; + final String detectorId3 = "detector-3-id"; + Tool tool = SearchAnomalyDetectorsTool.Factory.getInstance().create(Collections.emptyMap()); + + // Generate mock values and responses + SearchHit[] hits = new SearchHit[3]; + hits[0] = TestHelpers.generateSearchDetectorHit(detectorName1, detectorId1); + hits[1] = TestHelpers.generateSearchDetectorHit(detectorName2, detectorId2); + hits[2] = TestHelpers.generateSearchDetectorHit(detectorName3, detectorId3); + SearchResponse getDetectorsResponse = TestHelpers.generateSearchResponse(hits); + GetAnomalyDetectorResponse getDetectorProfileResponse = TestHelpers + .generateGetAnomalyDetectorResponses( + new String[] { detectorName1, detectorName2, detectorName3 }, + new String[] { DetectorStateString.Running.name(), DetectorStateString.Disabled.name(), DetectorStateString.Failed.name() } + ); + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + mockProfileApiCalls(getDetectorsResponse, getDetectorProfileResponse); + + tool.run(Map.of("running", "true", "failed", "true"), listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + String response = responseCaptor.getValue(); + assertTrue(response.contains(String.format(Locale.ROOT, "id=%s", detectorId1))); + assertTrue(response.contains(String.format(Locale.ROOT, "name=%s", detectorName1))); + assertTrue(response.contains(String.format(Locale.ROOT, "id=%s", detectorId3))); + assertTrue(response.contains(String.format(Locale.ROOT, "name=%s", detectorName3))); + assertTrue(response.contains(String.format(Locale.ROOT, "TotalAnomalyDetectors=%d", 2))); + } + + @Test + public void testRunWithCombinedDetectorStatesFalse() throws Exception { + final String detectorName1 = "detector-1"; + final String detectorId1 = "detector-1-id"; + final String detectorName2 = "detector-2"; + final String detectorId2 = "detector-2-id"; + final String detectorName3 = "detector-3"; + final String detectorId3 = "detector-3-id"; + Tool tool = SearchAnomalyDetectorsTool.Factory.getInstance().create(Collections.emptyMap()); + + // Generate mock values and responses + SearchHit[] hits = new SearchHit[3]; + hits[0] = TestHelpers.generateSearchDetectorHit(detectorName1, detectorId1); + hits[1] = TestHelpers.generateSearchDetectorHit(detectorName2, 
detectorId2); + hits[2] = TestHelpers.generateSearchDetectorHit(detectorName3, detectorId3); + SearchResponse getDetectorsResponse = TestHelpers.generateSearchResponse(hits); + GetAnomalyDetectorResponse getDetectorProfileResponse = TestHelpers + .generateGetAnomalyDetectorResponses( + new String[] { detectorName1, detectorName2, detectorName3 }, + new String[] { DetectorStateString.Running.name(), DetectorStateString.Disabled.name(), DetectorStateString.Failed.name() } + ); + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + mockProfileApiCalls(getDetectorsResponse, getDetectorProfileResponse); + + tool.run(Map.of("running", "false", "failed", "false"), listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertTrue(responseCaptor.getValue().contains("TotalAnomalyDetectors=1")); + } + + @Test + public void testRunWithCombinedDetectorStatesMixed() throws Exception { + final String detectorName1 = "detector-1"; + final String detectorId1 = "detector-1-id"; + final String detectorName2 = "detector-2"; + final String detectorId2 = "detector-2-id"; + final String detectorName3 = "detector-3"; + final String detectorId3 = "detector-3-id"; + Tool tool = SearchAnomalyDetectorsTool.Factory.getInstance().create(Collections.emptyMap()); + + // Generate mock values and responses + SearchHit[] hits = new SearchHit[3]; + hits[0] = TestHelpers.generateSearchDetectorHit(detectorName1, detectorId1); + hits[1] = TestHelpers.generateSearchDetectorHit(detectorName2, detectorId2); + hits[2] = TestHelpers.generateSearchDetectorHit(detectorName3, detectorId3); + SearchResponse getDetectorsResponse = TestHelpers.generateSearchResponse(hits); + GetAnomalyDetectorResponse getDetectorProfileResponse = TestHelpers + .generateGetAnomalyDetectorResponses( + new String[] { detectorName1, detectorName2, detectorName3 }, + new String[] { DetectorStateString.Running.name(), DetectorStateString.Disabled.name(), DetectorStateString.Failed.name() } + ); + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + mockProfileApiCalls(getDetectorsResponse, getDetectorProfileResponse); + + tool.run(Map.of("running", "true", "failed", "false"), listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + String response = responseCaptor.getValue(); + assertTrue(response.contains(String.format(Locale.ROOT, "id=%s", detectorId1))); + assertTrue(response.contains(String.format(Locale.ROOT, "name=%s", detectorName1))); + assertTrue(response.contains(String.format(Locale.ROOT, "TotalAnomalyDetectors=%d", 1))); + } + + @Test + public void testParseParams() throws Exception { + Tool tool = SearchAnomalyDetectorsTool.Factory.getInstance().create(Collections.emptyMap()); + Map validParams = new HashMap(); + validParams.put("detectorName", "foo"); + validParams.put("indices", "foo"); + validParams.put("highCardinality", "false"); + validParams.put("lastUpdateTime", "1234"); + validParams.put("sortOrder", "foo"); + validParams.put("size", "10"); + validParams.put("startIndex", "0"); + validParams.put("running", "false"); + + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + assertDoesNotThrow(() -> tool.run(validParams, listener)); + assertDoesNotThrow(() -> tool.run(Map.of("detectorNamePattern", "foo*"), listener)); + 
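+ // sortOrder parsing is expected to be case-insensitive, so the mixed-case value below must not throw.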
assertDoesNotThrow(() -> tool.run(Map.of("sortOrder", "AsC"), listener)); + } + + @Test + public void testValidate() { + Tool tool = SearchAnomalyDetectorsTool.Factory.getInstance().create(Collections.emptyMap()); + assertEquals(SearchAnomalyDetectorsTool.TYPE, tool.getType()); + assertTrue(tool.validate(emptyParams)); + assertTrue(tool.validate(nonEmptyParams)); + assertTrue(tool.validate(nullParams)); + } + + private void mockProfileApiCalls(SearchResponse getDetectorsResponse, GetAnomalyDetectorResponse getDetectorProfileResponse) { + // Mock return from initial search call + doAnswer((invocation) -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(getDetectorsResponse); + return null; + }).when(nodeClient).execute(any(SearchAnomalyDetectorAction.class), any(), any()); + + // Mock return from secondary detector profile call + doAnswer((invocation) -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(getDetectorProfileResponse); + return null; + }).when(nodeClient).execute(any(GetAnomalyDetectorAction.class), any(), any()); + } + + private String getExpectedResponseString(AnomalyDetector testDetector) { + return String + .format( + "AnomalyDetectors=[{id=%s,name=%s,type=%s,description=%s,index=%s,lastUpdateTime=%d}]TotalAnomalyDetectors=%d", + testDetector.getId(), + testDetector.getName(), + testDetector.getDetectorType(), + testDetector.getDescription(), + testDetector.getIndices().get(0), + testDetector.getLastUpdateTime().toEpochMilli(), + 1 + ); + + } +} diff --git a/src/test/java/org/opensearch/agent/tools/SearchAnomalyResultsToolTests.java b/src/test/java/org/opensearch/agent/tools/SearchAnomalyResultsToolTests.java new file mode 100644 index 00000000..ba11702f --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/SearchAnomalyResultsToolTests.java @@ -0,0 +1,209 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; + +import org.apache.lucene.search.TotalHits; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionType; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.agent.tools.utils.ToolConstants; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; 
+import org.opensearch.search.aggregations.Aggregations; + +public class SearchAnomalyResultsToolTests { + @Mock + private NamedWriteableRegistry namedWriteableRegistry; + @Mock + private NodeClient nodeClient; + + private Map<String, String> nullParams; + private Map<String, String> emptyParams; + private Map<String, String> nonEmptyParams; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + SearchAnomalyResultsTool.Factory.getInstance().init(nodeClient, namedWriteableRegistry); + + nullParams = null; + emptyParams = Collections.emptyMap(); + nonEmptyParams = Map.of("detectorId", "foo"); + } + + @Test + public void testParseParams() throws Exception { + Tool tool = SearchAnomalyResultsTool.Factory.getInstance().create(Collections.emptyMap()); + Map<String, String> validParams = new HashMap<>(); + validParams.put("detectorId", "foo"); + validParams.put("realTime", "true"); + validParams.put("anomalyGradeThreshold", "-1"); + validParams.put("dataStartTime", "1234"); + validParams.put("dataEndTime", "5678"); + validParams.put("sortOrder", "AsC"); + validParams.put("sortString", "foo.bar"); + validParams.put("size", "10"); + validParams.put("startIndex", "0"); + + @SuppressWarnings("unchecked") + ActionListener<String> listener = Mockito.mock(ActionListener.class); + assertDoesNotThrow(() -> tool.run(validParams, listener)); + } + + @Test + public void testRunWithInvalidAnomalyGradeParam() throws Exception { + Tool tool = SearchAnomalyResultsTool.Factory.getInstance().create(Collections.emptyMap()); + + @SuppressWarnings("unchecked") + ActionListener<String> listener = Mockito.mock(ActionListener.class); + assertThrows(NumberFormatException.class, () -> tool.run(Map.of("anomalyGradeThreshold", "foo"), listener)); + } + + @Test + public void testRunWithNoResults() throws Exception { + Tool tool = SearchAnomalyResultsTool.Factory.getInstance().create(Collections.emptyMap()); + + SearchHit[] hits = new SearchHit[0]; + + TotalHits totalHits = new TotalHits(hits.length, TotalHits.Relation.EQUAL_TO); + + SearchResponse getResultsResponse = new SearchResponse( + new SearchResponseSections(new SearchHits(hits, totalHits, 0), new Aggregations(new ArrayList<>()), null, false, null, null, 0), + null, + 0, + 0, + 0, + 0, + null, + null + ); + String expectedResponseStr = String.format(Locale.getDefault(), "AnomalyResults=[]TotalAnomalyResults=%d", hits.length); + + @SuppressWarnings("unchecked") + ActionListener<String> listener = Mockito.mock(ActionListener.class); + + doAnswer((invocation) -> { + ActionListener<SearchResponse> responseListener = invocation.getArgument(2); + responseListener.onResponse(getResultsResponse); + return null; + }).when(nodeClient).execute(any(ActionType.class), any(), any()); + + tool.run(emptyParams, listener); + ArgumentCaptor<String> responseCaptor = ArgumentCaptor.forClass(String.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals(expectedResponseStr, responseCaptor.getValue()); + } + + @Test + public void testRunWithSingleResult() throws Exception { + final String detectorId = "detector-1-id"; + final double anomalyGrade = 0.5; + final double confidence = 0.9; + Tool tool = SearchAnomalyResultsTool.Factory.getInstance().create(Collections.emptyMap()); + + XContentBuilder content = XContentBuilder.builder(XContentType.JSON.xContent()); + content.startObject(); + content.field("detector_id", detectorId); + content.field("anomaly_grade", anomalyGrade); + content.field("confidence", confidence); + content.endObject(); + SearchHit[] hits = new SearchHit[1]; + hits[0] = new SearchHit(0, detectorId, null, 
null).sourceRef(BytesReference.bytes(content)); + + TotalHits totalHits = new TotalHits(hits.length, TotalHits.Relation.EQUAL_TO); + + SearchResponse getResultsResponse = new SearchResponse( + new SearchResponseSections(new SearchHits(hits, totalHits, 0), new Aggregations(new ArrayList<>()), null, false, null, null, 0), + null, + 0, + 0, + 0, + 0, + null, + null + ); + String expectedResponseStr = String + .format( + "AnomalyResults=[{detectorId=%s,grade=%2.1f,confidence=%2.1f}]TotalAnomalyResults=%d", + detectorId, + anomalyGrade, + confidence, + hits.length + ); + + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + + doAnswer((invocation) -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(getResultsResponse); + return null; + }).when(nodeClient).execute(any(ActionType.class), any(), any()); + + tool.run(emptyParams, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals(expectedResponseStr, responseCaptor.getValue()); + } + + @Test + public void testDefaultIndexPatternIsSet() throws Exception { + Tool tool = SearchAnomalyResultsTool.Factory.getInstance().create(Collections.emptyMap()); + + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + + doAnswer((invocation) -> { + SearchRequest generatedRequest = invocation.getArgument(1); + String[] indices = generatedRequest.indices(); + assertNotNull(indices); + assertEquals(1, indices.length); + assertEquals(ToolConstants.AD_RESULTS_INDEX_PATTERN, indices[0]); + return null; + }).when(nodeClient).execute(any(ActionType.class), any(), any()); + + tool.run(emptyParams, listener); + } + + @Test + public void testValidate() { + Tool tool = SearchAnomalyResultsTool.Factory.getInstance().create(Collections.emptyMap()); + assertEquals(SearchAnomalyResultsTool.TYPE, tool.getType()); + assertTrue(tool.validate(emptyParams)); + assertTrue(tool.validate(nonEmptyParams)); + assertTrue(tool.validate(nullParams)); + } +} diff --git a/src/test/java/org/opensearch/agent/tools/SearchMonitorsToolTests.java b/src/test/java/org/opensearch/agent/tools/SearchMonitorsToolTests.java new file mode 100644 index 00000000..250ce5a2 --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/SearchMonitorsToolTests.java @@ -0,0 +1,272 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.time.Instant; +import java.time.ZoneId; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.apache.lucene.search.TotalHits; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionType; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.client.AdminClient; +import org.opensearch.client.ClusterAdminClient; 
+import org.opensearch.client.IndicesAdminClient; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.alerting.model.CronSchedule; +import org.opensearch.commons.alerting.model.DataSources; +import org.opensearch.commons.alerting.model.Monitor; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.Aggregations; + +public class SearchMonitorsToolTests { + @Mock + private NodeClient nodeClient; + @Mock + private AdminClient adminClient; + @Mock + private IndicesAdminClient indicesAdminClient; + @Mock + private ClusterAdminClient clusterAdminClient; + + private Map nullParams; + private Map emptyParams; + private Map nonEmptyParams; + private Map monitorIdParams; + + private Monitor testMonitor; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + SearchMonitorsTool.Factory.getInstance().init(nodeClient); + + nullParams = null; + emptyParams = Collections.emptyMap(); + nonEmptyParams = Map.of("monitorName", "foo"); + monitorIdParams = Map.of("monitorId", "foo"); + testMonitor = new Monitor( + "monitor-1-id", + 0L, + "monitor-1", + true, + new CronSchedule("31 * * * *", ZoneId.of("Asia/Kolkata"), null), + Instant.now(), + Instant.now(), + Monitor.MonitorType.QUERY_LEVEL_MONITOR.getValue(), + new User("test-user", Collections.emptyList(), Collections.emptyList(), Collections.emptyList()), + 0, + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyMap(), + new DataSources(), + "" + ); + } + + @Test + public void testRunWithNoMonitors() throws Exception { + Tool tool = SearchMonitorsTool.Factory.getInstance().create(Collections.emptyMap()); + SearchResponse searchMonitorsResponse = getEmptySearchMonitorsResponse(); + String expectedResponseStr = "Monitors=[]TotalMonitors=0"; + + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + + doAnswer((invocation) -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(searchMonitorsResponse); + return null; + }).when(nodeClient).execute(any(ActionType.class), any(), any()); + + tool.run(emptyParams, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals(expectedResponseStr, responseCaptor.getValue()); + } + + @Test + public void testRunWithMonitorId() throws Exception { + Tool tool = SearchMonitorsTool.Factory.getInstance().create(Collections.emptyMap()); + + SearchResponse searchMonitorsResponse = getSearchMonitorsResponse(testMonitor); + String expectedResponseStr = getExpectedResponseString(testMonitor); + + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + + doAnswer((invocation) -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(searchMonitorsResponse); + return null; + }).when(nodeClient).execute(any(ActionType.class), any(), any()); + + tool.run(monitorIdParams, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + 
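+ // The captured response should be the single-monitor summary string built from testMonitor.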
assertEquals(expectedResponseStr, responseCaptor.getValue()); + } + + @Test + public void testRunWithMonitorIdNotFound() throws Exception { + Tool tool = SearchMonitorsTool.Factory.getInstance().create(Collections.emptyMap()); + + SearchResponse searchMonitorsResponse = getEmptySearchMonitorsResponse(); + String expectedResponseStr = "Monitors=[]TotalMonitors=0"; + + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + + doAnswer((invocation) -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(searchMonitorsResponse); + return null; + }).when(nodeClient).execute(any(ActionType.class), any(), any()); + + tool.run(monitorIdParams, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals(expectedResponseStr, responseCaptor.getValue()); + } + + @Test + public void testRunWithSingleMonitor() throws Exception { + Tool tool = SearchMonitorsTool.Factory.getInstance().create(Collections.emptyMap()); + + SearchResponse searchMonitorsResponse = getSearchMonitorsResponse(testMonitor); + String expectedResponseStr = getExpectedResponseString(testMonitor); + + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + + doAnswer((invocation) -> { + ActionListener responseListener = invocation.getArgument(2); + responseListener.onResponse(searchMonitorsResponse); + return null; + }).when(nodeClient).execute(any(ActionType.class), any(), any()); + + tool.run(emptyParams, listener); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(String.class); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals(expectedResponseStr, responseCaptor.getValue()); + } + + @Test + public void testParseParams() throws Exception { + Tool tool = SearchMonitorsTool.Factory.getInstance().create(Collections.emptyMap()); + Map validParams = new HashMap(); + validParams.put("monitorName", "foo"); + validParams.put("enabled", "true"); + validParams.put("hasTriggers", "true"); + validParams.put("indices", "bar"); + validParams.put("sortOrder", "ASC"); + validParams.put("sortString", "baz"); + validParams.put("size", "10"); + validParams.put("startIndex", "0"); + + @SuppressWarnings("unchecked") + ActionListener listener = Mockito.mock(ActionListener.class); + assertDoesNotThrow(() -> tool.run(validParams, listener)); + assertDoesNotThrow(() -> tool.run(Map.of("hasTriggers", "false"), listener)); + assertDoesNotThrow(() -> tool.run(Map.of("monitorNamePattern", "foo*"), listener)); + assertDoesNotThrow(() -> tool.run(Map.of("detectorId", "foo"), listener)); + assertDoesNotThrow(() -> tool.run(Map.of("sortOrder", "AsC"), listener)); + } + + @Test + public void testValidate() { + Tool tool = SearchMonitorsTool.Factory.getInstance().create(Collections.emptyMap()); + assertEquals(SearchMonitorsTool.TYPE, tool.getType()); + assertTrue(tool.validate(emptyParams)); + assertTrue(tool.validate(nonEmptyParams)); + assertTrue(tool.validate(monitorIdParams)); + assertTrue(tool.validate(nullParams)); + } + + private SearchResponse getSearchMonitorsResponse(Monitor monitor) throws Exception { + XContentBuilder content = XContentBuilder.builder(XContentType.JSON.xContent()); + content + .startObject() + .startObject("monitor") + .field("name", monitor.getName()) + .field("monitor_type", monitor.getType()) + .field("enabled", Boolean.toString(monitor.getEnabled())) + 
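+ // enabled_time and last_update_time are serialized as epoch-millis strings.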
.field("enabled_time", Long.toString(monitor.getEnabledTime().toEpochMilli())) + .field("last_update_time", Long.toString(monitor.getLastUpdateTime().toEpochMilli())) + .endObject() + .endObject(); + SearchHit[] hits = new SearchHit[1]; + hits[0] = new SearchHit(0, monitor.getId(), null, null).sourceRef(BytesReference.bytes(content)); + + TotalHits totalHits = new TotalHits(hits.length, TotalHits.Relation.EQUAL_TO); + + return new SearchResponse( + new SearchResponseSections(new SearchHits(hits, totalHits, 0), new Aggregations(new ArrayList<>()), null, false, null, null, 0), + null, + 0, + 0, + 0, + 0, + null, + null + ); + } + + private SearchResponse getEmptySearchMonitorsResponse() throws Exception { + SearchHit[] hits = new SearchHit[0]; + TotalHits totalHits = new TotalHits(hits.length, TotalHits.Relation.EQUAL_TO); + return new SearchResponse( + new SearchResponseSections(new SearchHits(hits, totalHits, 0), new Aggregations(new ArrayList<>()), null, false, null, null, 0), + null, + 0, + 0, + 0, + 0, + null, + null + ); + } + + private String getExpectedResponseString(Monitor testMonitor) { + return String + .format( + "Monitors=[{id=%s,name=%s,type=%s,enabled=%s,enabledTime=%d,lastUpdateTime=%d}]TotalMonitors=%d", + testMonitor.getId(), + testMonitor.getName(), + testMonitor.getType(), + testMonitor.getEnabled(), + testMonitor.getEnabledTime().toEpochMilli(), + testMonitor.getLastUpdateTime().toEpochMilli(), + 1 + ); + + } +} diff --git a/src/test/java/org/opensearch/agent/tools/ToolHelperTests.java b/src/test/java/org/opensearch/agent/tools/ToolHelperTests.java new file mode 100644 index 00000000..5b6dfa7f --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/ToolHelperTests.java @@ -0,0 +1,90 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.assertEquals; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.Test; +import org.opensearch.agent.tools.utils.ToolHelper; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class ToolHelperTests { + @Test + public void TestExtractFieldNamesTypes() { + Map indexMappings = Map + .of( + "response", + Map.of("type", "integer"), + "responseLatency", + Map.of("type", "float"), + "date", + Map.of("type", "date"), + "objectA", + Map.of("type", "object", "properties", Map.of("subA", Map.of("type", "keyword"))), + "objectB", + Map.of("properties", Map.of("subB", Map.of("type", "keyword"))), + "textC", + Map.of("type", "text", "fields", Map.of("subC", Map.of("type", "keyword"))), + "aliasD", + Map.of("type", "alias", "path", "date") + ); + Map result = new HashMap<>(); + ToolHelper.extractFieldNamesTypes(indexMappings, result, "", true); + assertMapEquals( + result, + Map + .of( + "response", + "integer", + "responseLatency", + "float", + "date", + "date", + "objectA.subA", + "keyword", + "objectB.subB", + "keyword", + "textC", + "text", + "textC.subC", + "keyword" + ) + ); + + Map result1 = new HashMap<>(); + ToolHelper.extractFieldNamesTypes(indexMappings, result1, "", false); + assertMapEquals( + result1, + Map + .of( + "response", + "integer", + "responseLatency", + "float", + "date", + "date", + "objectA.subA", + "keyword", + "objectB.subB", + "keyword", + "textC", + "text" + ) + ); + } + + private void assertMapEquals(Map expected, Map actual) { + assertEquals(expected.size(), actual.size()); + for (Map.Entry entry : expected.entrySet()) { + assertEquals(entry.getValue(), actual.get(entry.getKey())); + 
} + } +} diff --git a/src/test/java/org/opensearch/agent/tools/VectorDBToolTests.java b/src/test/java/org/opensearch/agent/tools/VectorDBToolTests.java new file mode 100644 index 00000000..635724a7 --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/VectorDBToolTests.java @@ -0,0 +1,135 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; + +import com.google.gson.JsonSyntaxException; + +import lombok.SneakyThrows; + +public class VectorDBToolTests { + public static final String TEST_QUERY_TEXT = "123fsd23134sdfouh"; + public static final String TEST_EMBEDDING_FIELD = "test embedding"; + public static final String TEST_MODEL_ID = "123fsd23134"; + public static final Integer TEST_K = 123; + public static final String TEST_NESTED_PATH = "nested_path"; + private Map<String, Object> params = new HashMap<>(); + + @Before + public void setup() { + params.put(VectorDBTool.INDEX_FIELD, AbstractRetrieverToolTests.TEST_INDEX); + params.put(VectorDBTool.EMBEDDING_FIELD, TEST_EMBEDDING_FIELD); + params.put(VectorDBTool.SOURCE_FIELD, gson.toJson(AbstractRetrieverToolTests.TEST_SOURCE_FIELDS)); + params.put(VectorDBTool.MODEL_ID_FIELD, TEST_MODEL_ID); + params.put(VectorDBTool.DOC_SIZE_FIELD, AbstractRetrieverToolTests.TEST_DOC_SIZE.toString()); + params.put(VectorDBTool.K_FIELD, TEST_K.toString()); + } + + @Test + @SneakyThrows + public void testCreateTool() { + VectorDBTool tool = VectorDBTool.Factory.getInstance().create(params); + assertEquals(AbstractRetrieverToolTests.TEST_INDEX, tool.getIndex()); + assertEquals(TEST_EMBEDDING_FIELD, tool.getEmbeddingField()); + assertEquals(AbstractRetrieverToolTests.TEST_SOURCE_FIELDS, tool.getSourceFields()); + assertEquals(TEST_MODEL_ID, tool.getModelId()); + assertEquals(AbstractRetrieverToolTests.TEST_DOC_SIZE, tool.getDocSize()); + assertEquals(TEST_K, tool.getK()); + assertEquals("VectorDBTool", tool.getType()); + assertEquals("VectorDBTool", tool.getName()); + assertEquals(VectorDBTool.DEFAULT_DESCRIPTION, VectorDBTool.Factory.getInstance().getDefaultDescription()); + } + + @Test + @SneakyThrows + public void testGetQueryBody() { + VectorDBTool tool = VectorDBTool.Factory.getInstance().create(params); + Map<String, Map<String, Map<String, Map<String, Object>>>> queryBody = gson.fromJson(tool.getQueryBody(TEST_QUERY_TEXT), Map.class); + assertEquals("123fsd23134sdfouh", queryBody.get("query").get("neural").get("test embedding").get("query_text")); + assertEquals("123fsd23134", queryBody.get("query").get("neural").get("test embedding").get("model_id")); + assertEquals(123.0, queryBody.get("query").get("neural").get("test embedding").get("k")); + } + + @Test + @SneakyThrows + public void testGetQueryBodyWithNestedPath() { + params.put(VectorDBTool.NESTED_PATH_FIELD, TEST_NESTED_PATH); + VectorDBTool tool = VectorDBTool.Factory.getInstance().create(params); + Map<String, Map<String, Map<String, Object>>> nestedQueryBody = gson.fromJson(tool.getQueryBody(TEST_QUERY_TEXT), Map.class); + assertEquals("nested_path", nestedQueryBody.get("query").get("nested").get("path")); + assertEquals("max", nestedQueryBody.get("query").get("nested").get("score_mode")); + Map<String, Map<String, Map<String, Object>>> queryBody = (Map<String, Map<String, Map<String, Object>>>) nestedQueryBody + .get("query") + .get("nested") + .get("query"); + assertEquals("123fsd23134sdfouh", queryBody.get("neural").get("test 
embedding").get("query_text")); + assertEquals("123fsd23134", queryBody.get("neural").get("test embedding").get("model_id")); + } + + @Test + @SneakyThrows + public void testGetQueryBodyWithJsonObjectString() { + VectorDBTool tool = VectorDBTool.Factory.getInstance().create(params); + String jsonInput = gson.toJson(Map.of("hi", "a")); + Map>>> queryBody = gson.fromJson(tool.getQueryBody(jsonInput), Map.class); + assertEquals("{\"hi\":\"a\"}", queryBody.get("query").get("neural").get("test embedding").get("query_text")); + assertEquals("123fsd23134", queryBody.get("query").get("neural").get("test embedding").get("model_id")); + assertEquals(123.0, queryBody.get("query").get("neural").get("test embedding").get("k")); + } + + @Test + @SneakyThrows + public void testGetQueryBodyWithIllegalParams() { + Map illegalParams1 = new HashMap<>(params); + illegalParams1.remove(VectorDBTool.MODEL_ID_FIELD); + VectorDBTool tool1 = VectorDBTool.Factory.getInstance().create(illegalParams1); + Exception exception1 = assertThrows( + IllegalArgumentException.class, + () -> tool1.getQueryBody(AbstractRetrieverToolTests.TEST_QUERY) + ); + assertEquals("Parameter [embedding_field] and [model_id] can not be null or empty.", exception1.getMessage()); + + Map illegalParams2 = new HashMap<>(params); + illegalParams2.remove(VectorDBTool.EMBEDDING_FIELD); + VectorDBTool tool2 = VectorDBTool.Factory.getInstance().create(illegalParams2); + Exception exception2 = assertThrows( + IllegalArgumentException.class, + () -> tool2.getQueryBody(AbstractRetrieverToolTests.TEST_QUERY) + ); + assertEquals("Parameter [embedding_field] and [model_id] can not be null or empty.", exception2.getMessage()); + } + + @Test + @SneakyThrows + public void testCreateToolsParseParams() { + assertThrows(ClassCastException.class, () -> VectorDBTool.Factory.getInstance().create(Map.of(VectorDBTool.INDEX_FIELD, 123))); + + assertThrows(ClassCastException.class, () -> VectorDBTool.Factory.getInstance().create(Map.of(VectorDBTool.EMBEDDING_FIELD, 123))); + + assertThrows(ClassCastException.class, () -> VectorDBTool.Factory.getInstance().create(Map.of(VectorDBTool.MODEL_ID_FIELD, 123))); + + assertThrows( + ClassCastException.class, + () -> VectorDBTool.Factory.getInstance().create(Map.of(VectorDBTool.NESTED_PATH_FIELD, 123)) + ); + + assertThrows(JsonSyntaxException.class, () -> VectorDBTool.Factory.getInstance().create(Map.of(VectorDBTool.SOURCE_FIELD, "123"))); + + // although it will be parsed as integer, but the parameters value should always be String + assertThrows(ClassCastException.class, () -> VectorDBTool.Factory.getInstance().create(Map.of(VectorDBTool.DOC_SIZE_FIELD, 123))); + + assertThrows(ClassCastException.class, () -> VectorDBTool.Factory.getInstance().create(Map.of(VectorDBTool.K_FIELD, 123))); + } +} diff --git a/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java b/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java new file mode 100644 index 00000000..658a3fc7 --- /dev/null +++ b/src/test/java/org/opensearch/integTest/BaseAgentToolsIT.java @@ -0,0 +1,391 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.function.Predicate; +import java.util.stream.Collectors; + +import org.apache.commons.lang3.StringUtils; +import org.apache.http.Header; +import 
org.apache.http.HttpEntity; +import org.apache.http.HttpHeaders; +import org.apache.http.entity.ContentType; +import org.apache.http.entity.StringEntity; +import org.apache.http.message.BasicHeader; +import org.apache.http.util.EntityUtils; +import org.junit.Before; +import org.opensearch.client.*; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.input.execute.agent.AgentMLInput; +import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; + +import com.google.common.collect.ImmutableList; +import com.google.gson.Gson; + +import lombok.SneakyThrows; + +public abstract class BaseAgentToolsIT extends OpenSearchSecureRestTestCase { + public static final Gson gson = new Gson(); + private static final int MAX_TASK_RESULT_QUERY_TIME_IN_SECOND = 60 * 5; + private static final int DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND = 1000; + + /** + * Update cluster settings to run ml models + */ + @Before + public void updateClusterSettings() { + updateClusterSettings("plugins.ml_commons.only_run_on_ml_node", false); + // default threshold for native circuit breaker is 90, it may be not enough on test runner machine + updateClusterSettings("plugins.ml_commons.native_memory_threshold", 100); + updateClusterSettings("plugins.ml_commons.jvm_heap_memory_threshold", 100); + updateClusterSettings("plugins.ml_commons.allow_registering_model_via_url", true); + updateClusterSettings("plugins.ml_commons.agent_framework_enabled", true); + } + + @SneakyThrows + protected void updateClusterSettings(String settingKey, Object value) { + XContentBuilder builder = XContentFactory + .jsonBuilder() + .startObject() + .startObject("persistent") + .field(settingKey, value) + .endObject() + .endObject(); + Response response = makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + builder.toString(), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + @SneakyThrows + private Map parseResponseToMap(Response response) { + Map responseInMap = XContentHelper + .convertToMap(XContentType.JSON.xContent(), EntityUtils.toString(response.getEntity()), false); + return responseInMap; + } + + @SneakyThrows + private Object parseFieldFromResponse(Response response, String field) { + assertNotNull(field); + Map map = parseResponseToMap(response); + Object result = map.get(field); + assertNotNull(result); + return result; + } + + protected String createConnector(String requestBody) { + Response response = makeRequest(client(), "POST", "/_plugins/_ml/connectors/_create", null, requestBody, null); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + return parseFieldFromResponse(response, 
MLModel.CONNECTOR_ID_FIELD).toString(); + } + + protected String registerModel(String requestBody) { + Response response = makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, requestBody, null); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + return parseFieldFromResponse(response, MLTask.TASK_ID_FIELD).toString(); + } + + protected String deployModel(String modelId) { + Response response = makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_deploy", null, (String) null, null); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + return parseFieldFromResponse(response, MLTask.TASK_ID_FIELD).toString(); + } + + protected String indexMonitor(String monitorAsJsonString) { + Response response = makeRequest(client(), "POST", "_plugins/_alerting/monitors", null, monitorAsJsonString, null); + + assertEquals(RestStatus.CREATED, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + return parseFieldFromResponse(response, "_id").toString(); + } + + protected void deleteMonitor(String monitorId) { + Response response = makeRequest(client(), "DELETE", "_plugins/_alerting/monitors/" + monitorId, null, (String) null, null); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + protected String indexDetector(String detectorAsJsonString) { + Response response = makeRequest(client(), "POST", "_plugins/_anomaly_detection/detectors", null, detectorAsJsonString, null); + + assertEquals(RestStatus.CREATED, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + return parseFieldFromResponse(response, "_id").toString(); + } + + protected void startDetector(String detectorId) { + Response response = makeRequest( + client(), + "POST", + "_plugins/_anomaly_detection/detectors/" + detectorId + "/_start", + null, + (String) null, + null + ); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + protected void stopDetector(String detectorId) { + Response response = makeRequest( + client(), + "POST", + "_plugins/_anomaly_detection/detectors/" + detectorId + "/_stop", + null, + (String) null, + null + ); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + protected void deleteDetector(String detectorId) { + Response response = makeRequest( + client(), + "DELETE", + "_plugins/_anomaly_detection/detectors/" + detectorId, + null, + (String) null, + null + ); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + @SneakyThrows + protected Map waitResponseMeetingCondition( + String method, + String endpoint, + String jsonEntity, + Predicate> condition + ) { + for (int i = 0; i < MAX_TASK_RESULT_QUERY_TIME_IN_SECOND; i++) { + Response response = makeRequest(client(), method, endpoint, null, jsonEntity, null); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + Map responseInMap = parseResponseToMap(response); + if (condition.test(responseInMap)) { + return responseInMap; + } + logger.error(String.format(Locale.ROOT, "The %s-th response: %s", i, responseInMap.toString())); + Thread.sleep(DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND); + } + fail("The response failed to meet condition after " + MAX_TASK_RESULT_QUERY_TIME_IN_SECOND + " seconds."); + return null; + } + + @SneakyThrows + protected Map waitTaskComplete(String taskId) { + Predicate> 
condition = responseInMap -> {
+            String state = responseInMap.get(MLTask.STATE_FIELD).toString();
+            return state.equals(MLTaskState.COMPLETED.toString());
+        };
+        return waitResponseMeetingCondition("GET", "/_plugins/_ml/tasks/" + taskId, (String) null, condition);
+    }
+
+    // Register the model, then deploy it. Returns the model_id once the model has been deployed.
+    protected String registerModelThenDeploy(String requestBody) {
+        String registerModelTaskId = registerModel(requestBody);
+        Map<String, Object> registerTaskResponseInMap = waitTaskComplete(registerModelTaskId);
+        String modelId = registerTaskResponseInMap.get(MLTask.MODEL_ID_FIELD).toString();
+        String deployModelTaskId = deployModel(modelId);
+        waitTaskComplete(deployModelTaskId);
+        return modelId;
+    }
+
+    @SneakyThrows
+    private void waitModelUndeployed(String modelId) {
+        Predicate<Map<String, Object>> condition = responseInMap -> {
+            String state = responseInMap.get(MLModel.MODEL_STATE_FIELD).toString();
+            return !state.equals(MLModelState.DEPLOYED.toString())
+                && !state.equals(MLModelState.DEPLOYING.toString())
+                && !state.equals(MLModelState.PARTIALLY_DEPLOYED.toString());
+        };
+        waitResponseMeetingCondition("GET", "/_plugins/_ml/models/" + modelId, (String) null, condition);
+    }
+
+    @SneakyThrows
+    protected void deleteModel(String modelId) {
+        // need to undeploy first as the model can be in use
+        makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_undeploy", null, (String) null, null);
+        waitModelUndeployed(modelId);
+        makeRequest(client(), "DELETE", "/_plugins/_ml/models/" + modelId, null, (String) null, null);
+    }
+
+    protected void createIndexWithConfiguration(String indexName, String indexConfiguration) throws Exception {
+        boolean indexExists = indexExists(indexName);
+        if (!indexExists) {
+            Response response = makeRequest(client(), "PUT", indexName, null, indexConfiguration, null);
+            Map<String, Object> responseInMap = parseResponseToMap(response);
+            assertEquals("true", responseInMap.get("acknowledged").toString());
+            assertEquals(indexName, responseInMap.get("index").toString());
+        }
+    }
+
+    protected void createIngestPipelineWithConfiguration(String pipelineName, String body) throws Exception {
+        Response response = makeRequest(client(), "PUT", "/_ingest/pipeline/" + pipelineName, null, body, null);
+        Map<String, Object> responseInMap = parseResponseToMap(response);
+        assertEquals("true", responseInMap.get("acknowledged").toString());
+    }
+
+    // Similar to deleteExternalIndices, but including indices with "." prefix vs. 
excluding them + protected void deleteSystemIndices() throws IOException { + final Response response = client().performRequest(new Request("GET", "/_cat/indices?format=json" + "&expand_wildcards=all")); + try ( + final XContentParser parser = JsonXContent.jsonXContent + .createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + response.getEntity().getContent() + ) + ) { + final XContentParser.Token token = parser.nextToken(); + final List> parserList; + if (token == XContentParser.Token.START_ARRAY) { + parserList = parser.listOrderedMap().stream().map(obj -> (Map) obj).collect(Collectors.toList()); + } else { + parserList = Collections.singletonList(parser.mapOrdered()); + } + + final List externalIndices = parserList + .stream() + .map(index -> (String) index.get("index")) + .filter(indexName -> indexName != null) + .filter(indexName -> indexName.startsWith(".")) + .collect(Collectors.toList()); + + for (final String indexName : externalIndices) { + Response deleteResponse = adminClient().performRequest(new Request("DELETE", "/" + indexName)); + Map responseInMap = parseResponseToMap(deleteResponse); + assertEquals( + String.format(Locale.ROOT, "delete index %s failed with response: %s", indexName, gson.toJson(responseInMap)), + "true", + responseInMap.get("acknowledged").toString() + ); + } + } + } + + @SneakyThrows + protected void addDocToIndex(String indexName, String docId, List fieldNames, List fieldContents) { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + for (int i = 0; i < fieldNames.size(); i++) { + builder.field(fieldNames.get(i), fieldContents.get(i)); + } + builder.endObject(); + Response response = makeRequest( + client(), + "POST", + "/" + indexName + "/_doc/" + docId + "?refresh=true", + null, + builder.toString(), + null + ); + assertEquals(RestStatus.CREATED, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + @SneakyThrows + protected void addDocToIndex(String indexName, String docId, String contents) { + Response response = makeRequest(client(), "POST", "/" + indexName + "/_doc/" + docId + "?refresh=true", null, contents, null); + assertEquals(RestStatus.CREATED, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + public String createAgent(String requestBody) { + Response response = makeRequest(client(), "POST", "/_plugins/_ml/agents/_register", null, requestBody, null); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + return parseFieldFromResponse(response, AgentMLInput.AGENT_ID_FIELD).toString(); + } + + private String parseStringResponseFromExecuteAgentResponse(Response response) { + Map responseInMap = parseResponseToMap(response); + Optional optionalResult = Optional + .ofNullable(responseInMap) + .map(m -> (List) m.get(ModelTensorOutput.INFERENCE_RESULT_FIELD)) + .map(l -> (Map) l.get(0)) + .map(m -> (List) m.get(ModelTensors.OUTPUT_FIELD)) + .map(l -> (Map) l.get(0)) + .map(m -> (String) (m.get(ModelTensor.RESULT_FIELD))); + return optionalResult.get(); + } + + // execute the agent, and return the String response from the json structure + // {"inference_results": [{"output": [{"name": "response","result": "the result to return."}]}]} + public String executeAgent(String agentId, String requestBody) { + Response response = makeRequest(client(), "POST", "/_plugins/_ml/agents/" + agentId + "/_execute", null, requestBody, null); + return parseStringResponseFromExecuteAgentResponse(response); + } + + public static 
Response makeRequest( + RestClient client, + String method, + String endpoint, + Map params, + String jsonEntity, + List
headers + ) { + HttpEntity httpEntity = StringUtils.isBlank(jsonEntity) ? null : new StringEntity(jsonEntity, ContentType.APPLICATION_JSON); + return makeRequest(client, method, endpoint, params, httpEntity, headers); + } + + public static Response makeRequest( + RestClient client, + String method, + String endpoint, + Map params, + HttpEntity entity, + List
headers + ) { + return makeRequest(client, method, endpoint, params, entity, headers, false); + } + + @SneakyThrows + public static Response makeRequest( + RestClient client, + String method, + String endpoint, + Map params, + HttpEntity entity, + List
headers, + boolean strictDeprecationMode + ) { + Request request = new Request(method, endpoint); + + RequestOptions.Builder options = RequestOptions.DEFAULT.toBuilder(); + if (headers != null) { + headers.forEach(header -> options.addHeader(header.getName(), header.getValue())); + } + options.setWarningsHandler(strictDeprecationMode ? WarningsHandler.STRICT : WarningsHandler.PERMISSIVE); + request.setOptions(options.build()); + + if (params != null) { + params.forEach(request::addParameter); + } + if (entity != null) { + request.setEntity(entity); + } + return client.performRequest(request); + } +} diff --git a/src/test/java/org/opensearch/integTest/CreateAnomalyDetectorToolIT.java b/src/test/java/org/opensearch/integTest/CreateAnomalyDetectorToolIT.java new file mode 100644 index 00000000..648a381b --- /dev/null +++ b/src/test/java/org/opensearch/integTest/CreateAnomalyDetectorToolIT.java @@ -0,0 +1,345 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; + +import org.hamcrest.MatcherAssert; +import org.opensearch.agent.tools.CreateAnomalyDetectorTool; +import org.opensearch.client.ResponseException; + +import lombok.SneakyThrows; + +public class CreateAnomalyDetectorToolIT extends ToolIntegrationTest { + private final String NORMAL_INDEX = "http_logs"; + private final String NORMAL_INDEX_WITH_NO_AVAILABLE_FIELDS = "products"; + private final String NORMAL_INDEX_WITH_NO_DATE_FIELDS = "normal_index_with_no_date_fields"; + private final String NORMAL_INDEX_WITH_NO_MAPPING = "normal_index_with_no_mapping"; + private final String ABNORMAL_INDEX = "abnormal_index"; + + @Override + List promptHandlers() { + PromptHandler createAnomalyDetectorToolHandler = new PromptHandler() { + @Override + String response(String prompt) { + int flag; + if (prompt.contains(NORMAL_INDEX)) { + flag = randomIntBetween(0, 9); + switch (flag) { + case 0: + return "{category_field=|aggregation_field=response,responseLatency|aggregation_method=count,avg}"; + case 1: + return "{category_field=ip|aggregation_field=response,responseLatency|aggregation_method=count,avg}"; + case 2: + return "{category_field=|aggregation_field=responseLatency|aggregation_method=avg}"; + case 3: + return "{category_field=country.keyword|aggregation_field=response,responseLatency|aggregation_method=count,avg}"; + case 4: + return "{category_field=country.keyword|aggregation_field=response.keyword|aggregation_method=count}"; + case 5: + return "{category_field=\"country.keyword\"|aggregation_field=\"response,responseLatency\"|aggregation_method=\"count,avg\"}"; + case 6: + return "{category_field=ip|aggregation_field=responseLatency|aggregation_method=avg}"; + case 7: + return "{category_field=\"ip\"|aggregation_field=\"responseLatency\"|aggregation_method=\"avg\"}"; + case 8: + return "{category_field= ip |aggregation_field= responseLatency |aggregation_method= avg }"; + case 9: + return "{category_field=\" ip \"|aggregation_field=\" responseLatency \"|aggregation_method=\" avg \"}"; + default: + return "{category_field=|aggregation_field=response|aggregation_method=count}"; + } + } else if (prompt.contains(NORMAL_INDEX_WITH_NO_AVAILABLE_FIELDS)) { + flag = randomIntBetween(0, 9); + switch (flag) { + case 0: + return "{category_field=|aggregation_field=|aggregation_method=}"; + case 1: + 
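// cases 1-9 are blank-value variants of case 0 (whitespace-only, quoted-empty, or partially missing fields) +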
return "{category_field= |aggregation_field= |aggregation_method= }"; + case 2: + return "{category_field=\"\"|aggregation_field=\"\"|aggregation_method=\"\"}"; + case 3: + return "{category_field=product|aggregation_field=|aggregation_method=sum}"; + case 4: + return "{category_field=product|aggregation_field=sales|aggregation_method=}"; + case 5: + return "{category_field=product|aggregation_field=\"\"|aggregation_method=sum}"; + case 6: + return "{category_field=product|aggregation_field=sales|aggregation_method=\"\"}"; + case 7: + return "{category_field=product|aggregation_field= |aggregation_method=sum}"; + case 8: + return "{category_field=product|aggregation_field=sales |aggregation_method= }"; + case 9: + return "{category_field=\"\"|aggregation_field= |aggregation_method=\"\" }"; + default: + return "{category_field=product|aggregation_field= |aggregation_method= }"; + } + } else { + flag = randomIntBetween(0, 1); + switch (flag) { + case 0: + return "wrong response"; + case 1: + return "{category_field=product}"; + default: + return "{category_field=}"; + } + } + } + + @Override + boolean apply(String prompt) { + return true; + } + }; + return List.of(createAnomalyDetectorToolHandler); + } + + @Override + String toolType() { + return CreateAnomalyDetectorTool.TYPE; + } + + public void testCreateAnomalyDetectorTool() { + prepareIndex(); + String agentId = registerAgent(); + String index; + if (randomIntBetween(0, 1) == 0) { + index = NORMAL_INDEX; + } else { + index = NORMAL_INDEX_WITH_NO_AVAILABLE_FIELDS; + } + String result = executeAgent(agentId, "{\"parameters\": {\"index\":\"" + index + "\"}}"); + assertTrue(result.contains("index")); + assertTrue(result.contains("categoryField")); + assertTrue(result.contains("aggregationField")); + assertTrue(result.contains("aggregationMethod")); + assertTrue(result.contains("dateFields")); + } + + public void testCreateAnomalyDetectorToolWithNonExistentModelId() { + prepareIndex(); + String agentId = registerAgentWithWrongModelId(); + Exception exception = assertThrows( + ResponseException.class, + () -> executeAgent(agentId, "{\"parameters\": {\"index\":\"" + ABNORMAL_INDEX + "\"}}") + ); + MatcherAssert.assertThat(exception.getMessage(), allOf(containsString("Failed to find model"))); + } + + public void testCreateAnomalyDetectorToolWithUnexpectedResult() { + prepareIndex(); + String agentId = registerAgent(); + + Exception exception = assertThrows( + ResponseException.class, + () -> executeAgent(agentId, "{\"parameters\": {\"index\":\"" + NORMAL_INDEX_WITH_NO_MAPPING + "\"}}") + ); + MatcherAssert + .assertThat( + exception.getMessage(), + allOf( + containsString( + "The index " + + NORMAL_INDEX_WITH_NO_MAPPING + + " doesn't have mapping metadata, please add data to it or using another index." + ) + ) + ); + + exception = assertThrows( + ResponseException.class, + () -> executeAgent(agentId, "{\"parameters\": {\"index\":\"" + NORMAL_INDEX_WITH_NO_DATE_FIELDS + "\"}}") + ); + MatcherAssert + .assertThat( + exception.getMessage(), + allOf( + containsString( + "The index " + + NORMAL_INDEX_WITH_NO_DATE_FIELDS + + " doesn't have date type fields, cannot create an anomaly detector for it." 
+ ) + ) + ); + + exception = assertThrows( + ResponseException.class, + () -> executeAgent(agentId, "{\"parameters\": {\"index\":\"" + ABNORMAL_INDEX + "\"}}") + ); + MatcherAssert + .assertThat( + exception.getMessage(), + allOf( + containsString( + "The inference result from remote endpoint is not valid, cannot extract the key information from the result." + ) + ) + ); + } + + public void testCreateAnomalyDetectorToolWithSystemIndex() { + prepareIndex(); + String agentId = registerAgent(); + Exception exception = assertThrows( + ResponseException.class, + () -> executeAgent(agentId, "{\"parameters\": {\"index\": \".test\"}}") + ); + MatcherAssert + .assertThat( + exception.getMessage(), + allOf( + containsString( + "CreateAnomalyDetectionTool doesn't support searching indices starting with '.' since it could be system index, current searching index name: .test" + ) + ) + ); + } + + public void testCreateAnomalyDetectorToolWithMissingIndex() { + prepareIndex(); + String agentId = registerAgent(); + Exception exception = assertThrows( + ResponseException.class, + () -> executeAgent(agentId, "{\"parameters\": {\"index\": \"non-existent\"}}") + ); + MatcherAssert + .assertThat( + exception.getMessage(), + allOf( + containsString( + "Return this final answer to human directly and do not use other tools: 'The index doesn't exist, please provide another index and retry'. Please try to directly send this message to human to ask for index name" + ) + ) + ); + } + + public void testCreateAnomalyDetectorToolWithEmptyInput() { + prepareIndex(); + String agentId = registerAgent(); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {}}")); + MatcherAssert + .assertThat( + exception.getMessage(), + allOf( + containsString( + "Return this final answer to human directly and do not use other tools: 'Please provide index name'. 
Please try to directly send this message to human to ask for index name" + ) + ) + ); + } + + @SneakyThrows + private void prepareIndex() { + createIndexWithConfiguration( + NORMAL_INDEX, + "{\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"response\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"responseLatency\": {\n" + + " \"type\": \"float\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}" + ); + addDocToIndex(NORMAL_INDEX, "0", List.of("response", "responseLatency", "date"), List.of(200, 0.15, "2024-07-03T10:22:56,520")); + addDocToIndex(NORMAL_INDEX, "1", List.of("response", "responseLatency", "date"), List.of(200, 3.15, "2024-07-03T10:22:57,520")); + + createIndexWithConfiguration( + NORMAL_INDEX_WITH_NO_AVAILABLE_FIELDS, + "{\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"product\": {\n" + + " " + + " \"type\": \"keyword\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}" + ); + addDocToIndex(NORMAL_INDEX_WITH_NO_AVAILABLE_FIELDS, "0", List.of("product", "date"), List.of(1, "2024-07-03T10:22:56,520")); + addDocToIndex(NORMAL_INDEX_WITH_NO_AVAILABLE_FIELDS, "1", List.of("product", "date"), List.of(2, "2024-07-03T10:22:57,520")); + + createIndexWithConfiguration( + NORMAL_INDEX_WITH_NO_DATE_FIELDS, + "{\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"product\": {\n" + + " " + + " \"type\": \"keyword\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}" + ); + addDocToIndex(NORMAL_INDEX_WITH_NO_DATE_FIELDS, "0", List.of("product"), List.of(1)); + addDocToIndex(NORMAL_INDEX_WITH_NO_DATE_FIELDS, "1", List.of("product"), List.of(2)); + + createIndexWithConfiguration(NORMAL_INDEX_WITH_NO_MAPPING, "{}"); + + createIndexWithConfiguration( + ABNORMAL_INDEX, + "{\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"date\": {\n" + + " " + + " \"type\": \"date\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}" + ); + addDocToIndex(ABNORMAL_INDEX, "0", List.of("date"), List.of(1, "2024-07-03T10:22:56,520")); + addDocToIndex(ABNORMAL_INDEX, "1", List.of("date"), List.of(2, "2024-07-03T10:22:57,520")); + } + + @SneakyThrows + private String registerAgentWithWrongModelId() { + String registerAgentRequestBody = Files + .readString( + Path + .of( + this + .getClass() + .getClassLoader() + .getResource("org/opensearch/agent/tools/register_flow_agent_of_create_anomaly_detector_tool_request_body.json") + .toURI() + ) + ); + registerAgentRequestBody = registerAgentRequestBody.replace("", "non-existent"); + return createAgent(registerAgentRequestBody); + } + + @SneakyThrows + private String registerAgent() { + String registerAgentRequestBody = Files + .readString( + Path + .of( + this + .getClass() + .getClassLoader() + .getResource("org/opensearch/agent/tools/register_flow_agent_of_create_anomaly_detector_tool_request_body.json") + .toURI() + ) + ); + registerAgentRequestBody = registerAgentRequestBody.replace("", modelId); + return createAgent(registerAgentRequestBody); + } +} diff --git a/src/test/java/org/opensearch/integTest/MockHttpServer.java b/src/test/java/org/opensearch/integTest/MockHttpServer.java new file mode 100644 index 00000000..f64adcd1 --- /dev/null +++ b/src/test/java/org/opensearch/integTest/MockHttpServer.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import java.io.IOException; +import java.io.InputStream; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.nio.charset.StandardCharsets; +import java.util.List; 
+import java.util.Map; + +import com.google.gson.Gson; +import com.sun.net.httpserver.HttpServer; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class MockHttpServer { + + private static Gson gson = new Gson(); + + public static HttpServer setupMockLLM(List promptHandlers) throws IOException { + HttpServer server = HttpServer.create(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0); + + server.createContext("/invoke", exchange -> { + InputStream ins = exchange.getRequestBody(); + String req = new String(ins.readAllBytes(), StandardCharsets.UTF_8); + Map map = gson.fromJson(req, Map.class); + String prompt = map.get("prompt"); + log.debug("prompt received: {}", prompt); + + String llmRes = ""; + for (PromptHandler promptHandler : promptHandlers) { + if (promptHandler.apply(prompt)) { + PromptHandler.LLMResponse llmResponse = new PromptHandler.LLMResponse(); + llmResponse.setCompletion(promptHandler.response(prompt)); + llmRes = gson.toJson(llmResponse); + break; + } + } + byte[] llmResBytes = llmRes.getBytes(StandardCharsets.UTF_8); + exchange.sendResponseHeaders(200, llmResBytes.length); + exchange.getResponseBody().write(llmResBytes); + exchange.close(); + }); + return server; + } +} diff --git a/src/test/java/org/opensearch/integTest/NeuralSparseSearchToolIT.java b/src/test/java/org/opensearch/integTest/NeuralSparseSearchToolIT.java new file mode 100644 index 00000000..b7618468 --- /dev/null +++ b/src/test/java/org/opensearch/integTest/NeuralSparseSearchToolIT.java @@ -0,0 +1,233 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.opensearch.client.ResponseException; + +import lombok.SneakyThrows; + +public class NeuralSparseSearchToolIT extends BaseAgentToolsIT { + public static String TEST_INDEX_NAME = "test_index"; + public static String TEST_NESTED_INDEX_NAME = "test_index_nested"; + + private String modelId; + private String registerAgentRequestBody; + + @SneakyThrows + private void prepareModel() { + String requestBody = Files + .readString( + Path + .of( + this + .getClass() + .getClassLoader() + .getResource("org/opensearch/agent/tools/register_sparse_encoding_model_request_body.json") + .toURI() + ) + ); + modelId = registerModelThenDeploy(requestBody); + } + + @SneakyThrows + private void prepareIndex() { + createIndexWithConfiguration( + TEST_INDEX_NAME, + "{\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"text\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"embedding\": {\n" + + " \"type\": \"rank_features\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}" + ); + addDocToIndex(TEST_INDEX_NAME, "0", List.of("text", "embedding"), List.of("text doc 1", Map.of("hello", 1, "world", 2))); + addDocToIndex(TEST_INDEX_NAME, "1", List.of("text", "embedding"), List.of("text doc 2", Map.of("a", 3, "b", 4))); + addDocToIndex(TEST_INDEX_NAME, "2", List.of("text", "embedding"), List.of("text doc 3", Map.of("test", 5, "a", 6))); + } + + @SneakyThrows + private void prepareNestedIndex() { + createIndexWithConfiguration( + TEST_NESTED_INDEX_NAME, + "{\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"text\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"embedding\": {\n" + + " \"type\": 
\"nested\",\n" + + " \"properties\":{\n" + + " \"sparse\":{\n" + + " \"type\":\"rank_features\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}" + ); + addDocToIndex( + TEST_NESTED_INDEX_NAME, + "0", + List.of("text", "embedding"), + List.of("text doc 1", Map.of("sparse", List.of(Map.of("hello", 1, "world", 2)))) + ); + addDocToIndex( + TEST_NESTED_INDEX_NAME, + "1", + List.of("text", "embedding"), + List.of("text doc 2", Map.of("sparse", List.of(Map.of("a", 3, "b", 4)))) + ); + addDocToIndex( + TEST_NESTED_INDEX_NAME, + "2", + List.of("text", "embedding"), + List.of("text doc 3", Map.of("sparse", List.of(Map.of("test", 5, "a", 6)))) + ); + } + + @Before + @SneakyThrows + public void setUp() { + super.setUp(); + prepareModel(); + prepareIndex(); + prepareNestedIndex(); + registerAgentRequestBody = Files + .readString( + Path + .of( + this + .getClass() + .getClassLoader() + .getResource("org/opensearch/agent/tools/register_flow_agent_of_neural_sparse_search_tool_request_body.json") + .toURI() + ) + ); + registerAgentRequestBody = registerAgentRequestBody.replace("", modelId); + } + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + deleteExternalIndices(); + deleteModel(modelId); + } + + public void testNeuralSparseSearchToolInFlowAgent() { + String agentId = createAgent(registerAgentRequestBody); + // successful case + String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}"); + assertEquals( + "The agent execute response not equal with expected.", + "{\"_index\":\"test_index\",\"_source\":{\"text\":\"text doc 3\"},\"_id\":\"2\",\"_score\":2.4136734}\n" + + "{\"_index\":\"test_index\",\"_source\":{\"text\":\"text doc 2\"},\"_id\":\"1\",\"_score\":1.2068367}\n", + result + ); + + // use non-exist token to test the case the tool can not find match docs. 
+ String result2 = executeAgent(agentId, "{\"parameters\": {\"question\": \"c\"}}"); + assertEquals("The agent execute response not equal with expected.", "Can not get any match from search result.", result2); + + // if blank input, call onFailure and get exception + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"\"}}")); + + org.hamcrest.MatcherAssert + .assertThat( + exception.getMessage(), + allOf(containsString("[input] is null or empty, can not process it."), containsString("IllegalArgumentException")) + ); + + // use json string input + String jsonInput = gson.toJson(Map.of("parameters", Map.of("question", gson.toJson(Map.of("hi", "a"))))); + String result3 = executeAgent(agentId, jsonInput); + assertEquals( + "The agent execute response not equal with expected.", + "{\"_index\":\"test_index\",\"_source\":{\"text\":\"text doc 3\"},\"_id\":\"2\",\"_score\":2.4136734}\n" + + "{\"_index\":\"test_index\",\"_source\":{\"text\":\"text doc 2\"},\"_id\":\"1\",\"_score\":1.2068367}\n", + result3 + ); + } + + public void testNeuralSparseSearchToolInFlowAgent_withNestedIndex() { + String registerAgentRequestBodyNested = registerAgentRequestBody; + registerAgentRequestBodyNested = registerAgentRequestBodyNested.replace("\"nested_path\": \"\"", "\"nested_path\": \"embedding\""); + registerAgentRequestBodyNested = registerAgentRequestBodyNested + .replace("\"embedding_field\": \"embedding\"", "\"embedding_field\": \"embedding.sparse\""); + registerAgentRequestBodyNested = registerAgentRequestBodyNested + .replace("\"index\": \"test_index\"", "\"index\": \"test_index_nested\""); + String agentId = createAgent(registerAgentRequestBodyNested); + String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}"); + assertEquals( + "The agent execute response not equal with expected.", + "{\"_index\":\"test_index_nested\",\"_source\":{\"text\":\"text doc 3\"},\"_id\":\"2\",\"_score\":2.4136734}\n" + + "{\"_index\":\"test_index_nested\",\"_source\":{\"text\":\"text doc 2\"},\"_id\":\"1\",\"_score\":1.2068367}\n", + result + ); + } + + public void testNeuralSparseSearchToolInFlowAgent_withIllegalSourceField_thenGetEmptySource() { + String agentId = createAgent(registerAgentRequestBody.replace("text", "text2")); + String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}"); + assertEquals( + "The agent execute response not equal with expected.", + "{\"_index\":\"test_index\",\"_source\":{},\"_id\":\"2\",\"_score\":2.4136734}\n" + + "{\"_index\":\"test_index\",\"_source\":{},\"_id\":\"1\",\"_score\":1.2068367}\n", + result + ); + } + + public void testNeuralSparseSearchToolInFlowAgent_withIllegalEmbeddingField_thenThrowException() { + String agentId = createAgent(registerAgentRequestBody.replace("\"embedding\"", "\"embedding2\"")); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}")); + + org.hamcrest.MatcherAssert + .assertThat( + exception.getMessage(), + allOf( + containsString("[neural_sparse] query only works on [rank_features] fields"), + containsString("IllegalArgumentException") + ) + ); + } + + public void testNeuralSparseSearchToolInFlowAgent_withIllegalIndexField_thenThrowException() { + String agentId = createAgent(registerAgentRequestBody.replace("test_index", "test_index2")); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": 
\"a\"}}")); + + org.hamcrest.MatcherAssert + .assertThat( + exception.getMessage(), + allOf(containsString("no such index [test_index2]"), containsString("IndexNotFoundException")) + ); + } + + public void testNeuralSparseSearchToolInFlowAgent_withIllegalModelIdField_thenThrowException() { + String agentId = createAgent(registerAgentRequestBody.replace(modelId, "test_model_id")); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}")); + + org.hamcrest.MatcherAssert + .assertThat(exception.getMessage(), allOf(containsString("Failed to find model"), containsString("OpenSearchStatusException"))); + } +} diff --git a/src/test/java/org/opensearch/integTest/OpenSearchSecureRestTestCase.java b/src/test/java/org/opensearch/integTest/OpenSearchSecureRestTestCase.java new file mode 100644 index 00000000..2838f1f2 --- /dev/null +++ b/src/test/java/org/opensearch/integTest/OpenSearchSecureRestTestCase.java @@ -0,0 +1,163 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +import org.apache.http.Header; +import org.apache.http.HttpHost; +import org.apache.http.auth.AuthScope; +import org.apache.http.auth.UsernamePasswordCredentials; +import org.apache.http.client.CredentialsProvider; +import org.apache.http.conn.ssl.NoopHostnameVerifier; +import org.apache.http.impl.client.BasicCredentialsProvider; +import org.apache.http.message.BasicHeader; +import org.apache.http.ssl.SSLContextBuilder; +import org.junit.After; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.client.RestClient; +import org.opensearch.client.RestClientBuilder; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.test.rest.OpenSearchRestTestCase; + +/** + * Base class for running the integration tests on a secure cluster. The plugin IT test should either extend this + * class or create another base class by extending this class to make sure that their IT can be run on secure clusters. 
+ */ +public abstract class OpenSearchSecureRestTestCase extends OpenSearchRestTestCase { + + private static final String PROTOCOL_HTTP = "http"; + private static final String PROTOCOL_HTTPS = "https"; + private static final String SYS_PROPERTY_KEY_HTTPS = "https"; + private static final String SYS_PROPERTY_KEY_CLUSTER_ENDPOINT = "tests.rest.cluster"; + private static final String SYS_PROPERTY_KEY_USER = "user"; + private static final String SYS_PROPERTY_KEY_PASSWORD = "password"; + private static final String DEFAULT_SOCKET_TIMEOUT = "60s"; + private static final String INTERNAL_INDICES_PREFIX = "."; + private static String protocol; + + @Override + protected String getProtocol() { + if (protocol == null) { + protocol = readProtocolFromSystemProperty(); + } + return protocol; + } + + private String readProtocolFromSystemProperty() { + final boolean isHttps = Optional.ofNullable(System.getProperty(SYS_PROPERTY_KEY_HTTPS)).map("true"::equalsIgnoreCase).orElse(false); + if (!isHttps) { + return PROTOCOL_HTTP; + } + + // currently only external cluster is supported for security enabled testing + if (Optional.ofNullable(System.getProperty(SYS_PROPERTY_KEY_CLUSTER_ENDPOINT)).isEmpty()) { + throw new RuntimeException("cluster url should be provided for security enabled testing"); + } + return PROTOCOL_HTTPS; + } + + @Override + protected RestClient buildClient(Settings settings, HttpHost[] hosts) throws IOException { + final RestClientBuilder builder = RestClient.builder(hosts); + if (PROTOCOL_HTTPS.equals(getProtocol())) { + configureHttpsClient(builder, settings); + } else { + configureClient(builder, settings); + } + + return builder.build(); + } + + private void configureHttpsClient(final RestClientBuilder builder, final Settings settings) { + final Map headers = ThreadContext.buildDefaultHeaders(settings); + final Header[] defaultHeaders = new Header[headers.size()]; + int i = 0; + for (Map.Entry entry : headers.entrySet()) { + defaultHeaders[i++] = new BasicHeader(entry.getKey(), entry.getValue()); + } + builder.setDefaultHeaders(defaultHeaders); + builder.setHttpClientConfigCallback(httpClientBuilder -> { + final String userName = Optional + .ofNullable(System.getProperty(SYS_PROPERTY_KEY_USER)) + .orElseThrow(() -> new RuntimeException("user name is missing")); + final String password = Optional + .ofNullable(System.getProperty(SYS_PROPERTY_KEY_PASSWORD)) + .orElseThrow(() -> new RuntimeException("password is missing")); + final CredentialsProvider credentialsProvider = new BasicCredentialsProvider(); + credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials(userName, password)); + try { + return httpClientBuilder + .setDefaultCredentialsProvider(credentialsProvider) + // disable the certificate since our testing cluster just uses the default security configuration + .setSSLHostnameVerifier(NoopHostnameVerifier.INSTANCE) + .setSSLContext(SSLContextBuilder.create().loadTrustMaterial(null, (chains, authType) -> true).build()); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + + final String socketTimeoutString = settings.get(CLIENT_SOCKET_TIMEOUT); + final TimeValue socketTimeout = TimeValue + .parseTimeValue(socketTimeoutString == null ? 
DEFAULT_SOCKET_TIMEOUT : socketTimeoutString, CLIENT_SOCKET_TIMEOUT); + builder.setRequestConfigCallback(conf -> conf.setSocketTimeout(Math.toIntExact(socketTimeout.getMillis()))); + if (settings.hasValue(CLIENT_PATH_PREFIX)) { + builder.setPathPrefix(settings.get(CLIENT_PATH_PREFIX)); + } + } + + /** + * wipeAllIndices won't work since it cannot delete security index. Use deleteExternalIndices instead. + */ + @Override + protected boolean preserveIndicesUponCompletion() { + return true; + } + + @After + public void deleteExternalIndices() throws IOException { + final Response response = client().performRequest(new Request("GET", "/_cat/indices?format=json" + "&expand_wildcards=all")); + try ( + final XContentParser parser = JsonXContent.jsonXContent + .createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + response.getEntity().getContent() + ) + ) { + final XContentParser.Token token = parser.nextToken(); + final List> parserList; + if (token == XContentParser.Token.START_ARRAY) { + parserList = parser.listOrderedMap().stream().map(obj -> (Map) obj).collect(Collectors.toList()); + } else { + parserList = Collections.singletonList(parser.mapOrdered()); + } + + final List externalIndices = parserList + .stream() + .map(index -> (String) index.get("index")) + .filter(indexName -> indexName != null) + .filter(indexName -> !indexName.startsWith(INTERNAL_INDICES_PREFIX)) + .collect(Collectors.toList()); + + for (final String indexName : externalIndices) { + adminClient().performRequest(new Request("DELETE", "/" + indexName)); + } + } + } +} diff --git a/src/test/java/org/opensearch/integTest/PPLToolIT.java b/src/test/java/org/opensearch/integTest/PPLToolIT.java new file mode 100644 index 00000000..46cbd864 --- /dev/null +++ b/src/test/java/org/opensearch/integTest/PPLToolIT.java @@ -0,0 +1,191 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; + +import org.hamcrest.MatcherAssert; +import org.opensearch.agent.tools.PPLTool; +import org.opensearch.client.ResponseException; + +import lombok.SneakyThrows; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class PPLToolIT extends ToolIntegrationTest { + + @Override + List promptHandlers() { + PromptHandler PPLHandler = new PromptHandler() { + @Override + String response(String prompt) { + if (prompt.contains("correct")) { + return "source=employee | where age > 56 | stats COUNT() as cnt"; + } else { + return "source=employee | asd"; + } + } + + @Override + boolean apply(String prompt) { + return true; + } + }; + return List.of(PPLHandler); + } + + @Override + String toolType() { + return PPLTool.TYPE; + } + + @SneakyThrows + public void testPPLTool() { + prepareIndex(); + String agentId = registerAgent(); + String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"correct\", \"index\": \"employee\"}}"); + assertEquals( + "{\"ppl\":\"source\\u003demployee| where age \\u003e 56 | stats COUNT() as cnt\",\"executionResult\":\"{\\n \\\"schema\\\": [\\n {\\n \\\"name\\\": \\\"cnt\\\",\\n \\\"type\\\": \\\"integer\\\"\\n }\\n ],\\n \\\"datarows\\\": [\\n [\\n 0\\n ]\\n ],\\n \\\"total\\\": 1,\\n \\\"size\\\": 1\\n}\"}", + result + ); + } + + public void testPPLTool_withWrongPPLGenerated_thenThrowException() { + prepareIndex(); + String 
agentId = registerAgent(); + Exception exception = assertThrows( + ResponseException.class, + () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"wrong\", \"index\": \"employee\"}}") + ); + MatcherAssert.assertThat(exception.getMessage(), allOf(containsString("execute ppl:source=employee| asd, get error"))); + + } + + public void testPPLTool_withWrongModelId_thenThrowException() { + prepareIndex(); + String agentId = registerAgentWithWrongModelId(); + Exception exception = assertThrows( + ResponseException.class, + () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"correct\", \"index\": \"employee\"}}") + ); + MatcherAssert.assertThat(exception.getMessage(), allOf(containsString("Failed to find model"))); + + } + + public void testPPLTool_withSystemQuery_thenThrowException() { + prepareIndex(); + String agentId = registerAgent(); + Exception exception = assertThrows( + ResponseException.class, + () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"correct\", \"index\": \".employee\"}}") + ); + MatcherAssert + .assertThat( + exception.getMessage(), + allOf( + containsString( + "PPLTool doesn't support searching indices starting with '.' since it could be system index, current searching index name: .employee" + ) + ) + ); + + } + + public void testPPLTool_withNonExistingIndex_thenThrowException() { + prepareIndex(); + String agentId = registerAgent(); + Exception exception = assertThrows( + ResponseException.class, + () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"correct\", \"index\": \"employee2\"}}") + ); + MatcherAssert + .assertThat( + exception.getMessage(), + allOf( + containsString( + "Return this final answer to human directly and do not use other tools: 'Please provide index name'. Please try to directly send this message to human to ask for index name" + ) + ) + ); + } + + public void testPPLTool_withBlankInput_thenThrowException() { + prepareIndex(); + String agentId = registerAgent(); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}")); + MatcherAssert + .assertThat( + exception.getMessage(), + allOf( + containsString( + "Return this final answer to human directly and do not use other tools: 'Please provide index name'. 
Please try to directly send this message to human to ask for index name" + ) + ) + ); + } + + @SneakyThrows + private String registerAgent() { + String registerAgentRequestBody = Files + .readString( + Path + .of( + this + .getClass() + .getClassLoader() + .getResource("org/opensearch/agent/tools/register_flow_agent_of_ppl_tool_request_body.json") + .toURI() + ) + ); + registerAgentRequestBody = registerAgentRequestBody.replace("", modelId); + return createAgent(registerAgentRequestBody); + } + + @SneakyThrows + private String registerAgentWithWrongModelId() { + String registerAgentRequestBody = Files + .readString( + Path + .of( + this + .getClass() + .getClassLoader() + .getResource("org/opensearch/agent/tools/register_flow_agent_of_ppl_tool_request_body.json") + .toURI() + ) + ); + registerAgentRequestBody = registerAgentRequestBody.replace("", "wrong_model_id"); + return createAgent(registerAgentRequestBody); + } + + @SneakyThrows + private void prepareIndex() { + String testIndexName = "employee"; + createIndexWithConfiguration( + testIndexName, + "{\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"age\": {\n" + + " \"type\": \"long\"\n" + + " },\n" + + " \"name\": {\n" + + " \"type\": \"text\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}" + ); + addDocToIndex(testIndexName, "0", List.of("age", "name"), List.of(56, "john")); + addDocToIndex(testIndexName, "1", List.of("age", "name"), List.of(56, "smith")); + } + +} diff --git a/src/test/java/org/opensearch/integTest/PromptHandler.java b/src/test/java/org/opensearch/integTest/PromptHandler.java new file mode 100644 index 00000000..0ef03501 --- /dev/null +++ b/src/test/java/org/opensearch/integTest/PromptHandler.java @@ -0,0 +1,61 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import com.google.gson.annotations.SerializedName; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +public class PromptHandler { + + boolean apply(String prompt) { + return prompt.contains(llmThought().getQuestion()); + } + + LLMThought llmThought() { + return new LLMThought(); + } + + String response(String prompt) { + if (prompt.contains("Human: TOOL RESPONSE ")) { + return "```json{\n" + + " \"thought\": \"Thought: Now I know the final answer\",\n" + + " \"final_answer\": \"final answer\"\n" + + "}```"; + } else { + return "```json{\n" + + " \"thought\": \"Thought: Let me use tool to figure out\",\n" + + " \"action\": \"" + + this.llmThought().getAction() + + "\",\n" + + " \"action_input\": \"" + + this.llmThought().getActionInput() + + "\"\n" + + "}```"; + } + } + + @Builder + @NoArgsConstructor + @AllArgsConstructor + @Data + static class LLMThought { + String question; + String action; + String actionInput; + } + + @Data + static class LLMResponse { + String completion; + @SerializedName("stop_reason") + String stopReason = "stop_sequence"; + String stop = "\\n\\nHuman:"; + } +} diff --git a/src/test/java/org/opensearch/integTest/RAGToolIT.java b/src/test/java/org/opensearch/integTest/RAGToolIT.java new file mode 100644 index 00000000..eae20755 --- /dev/null +++ b/src/test/java/org/opensearch/integTest/RAGToolIT.java @@ -0,0 +1,521 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; +import static 
org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.opensearch.agent.tools.RAGTool; +import org.opensearch.client.ResponseException; + +import lombok.SneakyThrows; + +public class RAGToolIT extends ToolIntegrationTest { + + public static String TEST_NEURAL_INDEX_NAME = "test_neural_index"; + public static String TEST_NEURAL_SPARSE_INDEX_NAME = "test_neural_sparse_index"; + private String textEmbeddingModelId; + private String sparseEncodingModelId; + private String largeLanguageModelId; + private String registerAgentWithNeuralQueryRequestBody; + private String registerAgentWithNeuralSparseQueryRequestBody; + private String registerAgentWithNeuralQueryAndLLMRequestBody; + private String mockLLMResponseWithSource = "{\n" + + " \"inference_results\": [\n" + + " {\n" + + " \"output\": [\n" + + " {\n" + + " \"name\": \"response\",\n" + + " \"result\": \"\"\" Based on the context given:\n" + + " a, b, c are alphabets.\"\"\"\n" + + " }\n" + + " ]\n" + + " }\n" + + " ]\n" + + "}"; + private String mockLLMResponseWithoutSource = "{\n" + + " \"inference_results\": [\n" + + " {\n" + + " \"output\": [\n" + + " {\n" + + " \"name\": \"response\",\n" + + " \"result\": \"\"\" Based on the context given:\n" + + " I do not see any information about a, b, c\". So I would have to say I don't know the answer to your question based on this context..\"\"\"\n" + + " }\n" + + " ]\n" + + " }\n" + + " ]\n" + + "}"; + private String registerAgentWithNeuralSparseQueryAndLLMRequestBody; + + public RAGToolIT() throws IOException, URISyntaxException {} + + @SneakyThrows + private void prepareModel() { + String requestBody = Files + .readString( + Path + .of( + this + .getClass() + .getClassLoader() + .getResource("org/opensearch/agent/tools/register_text_embedding_model_request_body.json") + .toURI() + ) + ); + textEmbeddingModelId = registerModelThenDeploy(requestBody); + + String requestBody1 = Files + .readString( + Path + .of( + this + .getClass() + .getClassLoader() + .getResource("org/opensearch/agent/tools/register_sparse_encoding_model_request_body.json") + .toURI() + ) + ); + sparseEncodingModelId = registerModelThenDeploy(requestBody1); + largeLanguageModelId = this.modelId; + } + + @SneakyThrows + private void prepareIndex() { + // prepare index for neural sparse query type + createIndexWithConfiguration( + TEST_NEURAL_SPARSE_INDEX_NAME, + "{\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"text\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"embedding\": {\n" + + " \"type\": \"rank_features\"\n" + + " }\n" + + " }\n" + + " }\n" + + "}" + ); + addDocToIndex( + TEST_NEURAL_SPARSE_INDEX_NAME, + "0", + List.of("text", "embedding"), + List.of("hello world", Map.of("hello", 1, "world", 2)) + ); + addDocToIndex(TEST_NEURAL_SPARSE_INDEX_NAME, "1", List.of("text", "embedding"), List.of("a b", Map.of("a", 3, "b", 4))); + + // prepare index for neural query type + String pipelineConfig = "{\n" + + " \"description\": \"text embedding pipeline\",\n" + + " \"processors\": [\n" + + " {\n" + + " \"text_embedding\": {\n" + + " \"model_id\": \"" + + textEmbeddingModelId + + "\",\n" + + " \"field_map\": {\n" + + " \"text\": \"embedding\"\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + 
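// register the ingest pipeline before the index so the documents added below get their "text" field embedded at index time +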
createIngestPipelineWithConfiguration("test-embedding-model", pipelineConfig); + + String indexMapping = "{\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"text\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"embedding\": {\n" + + " \"type\": \"knn_vector\",\n" + + " \"dimension\": 768,\n" + + " \"method\": {\n" + + " \"name\": \"hnsw\",\n" + + " \"space_type\": \"l2\",\n" + + " \"engine\": \"lucene\",\n" + + " \"parameters\": {\n" + + " \"ef_construction\": 128,\n" + + " \"m\": 24\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " },\n" + + " \"settings\": {\n" + + " \"index\": {\n" + + " \"knn.space_type\": \"cosinesimil\",\n" + + " \"default_pipeline\": \"test-embedding-model\",\n" + + " \"knn\": \"true\"\n" + + " }\n" + + " }\n" + + "}"; + + createIndexWithConfiguration(TEST_NEURAL_INDEX_NAME, indexMapping); + + addDocToIndex(TEST_NEURAL_INDEX_NAME, "0", List.of("text"), List.of("hello world")); + + addDocToIndex(TEST_NEURAL_INDEX_NAME, "1", List.of("text"), List.of("a b")); + } + + @Before + @SneakyThrows + public void setUp() { + super.setUp(); + prepareModel(); + prepareIndex(); + String registerAgentWithNeuralQueryRequestBodyFile = Files + .readString( + Path + .of( + this + .getClass() + .getClassLoader() + .getResource( + "org/opensearch/agent/tools/register_flow_agent_of_ragtool_with_neural_query_type_request_body.json" + ) + .toURI() + ) + ); + registerAgentWithNeuralQueryRequestBody = registerAgentWithNeuralQueryRequestBodyFile + .replace("", textEmbeddingModelId) + .replace("", TEST_NEURAL_INDEX_NAME); + + registerAgentWithNeuralSparseQueryRequestBody = registerAgentWithNeuralQueryRequestBodyFile + .replace("", sparseEncodingModelId) + .replace("", TEST_NEURAL_SPARSE_INDEX_NAME) + .replace("\"query_type\": \"neural\"", "\"query_type\": \"neural_sparse\""); + + registerAgentWithNeuralQueryAndLLMRequestBody = registerAgentWithNeuralQueryRequestBodyFile + .replace("", textEmbeddingModelId + "\" ,\n \"inference_model_id\": \"" + largeLanguageModelId) + .replace("", TEST_NEURAL_INDEX_NAME) + .replace("false", "true"); + registerAgentWithNeuralSparseQueryAndLLMRequestBody = registerAgentWithNeuralQueryRequestBodyFile + .replace("", sparseEncodingModelId + "\" ,\n \"inference_model_id\": \"" + largeLanguageModelId) + .replace("", TEST_NEURAL_SPARSE_INDEX_NAME) + .replace("\"query_type\": \"neural\"", "\"query_type\": \"neural_sparse\"") + .replace("false", "true"); + } + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + deleteExternalIndices(); + deleteModel(textEmbeddingModelId); + deleteModel(sparseEncodingModelId); + } + + public void testRAGToolWithNeuralQuery() { + String agentId = createAgent(registerAgentWithNeuralQueryRequestBody); + + // neural query to test match similar text, doc1 match with higher score + String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"c\"}}"); + + // To allow digits variation from model output, using string contains to match + assertTrue( + result.contains("{\"_index\":\"test_neural_index\",\"_source\":{\"text\":\"hello world\"},\"_id\":\"0\",\"_score\":0.70467") + ); + assertTrue(result.contains("{\"_index\":\"test_neural_index\",\"_source\":{\"text\":\"a b\"},\"_id\":\"1\",\"_score\":0.26499")); + + // neural query to test match exact same text case, doc0 match with higher score + String result1 = executeAgent(agentId, "{\"parameters\": {\"question\": \"hello\"}}"); + + // To allow digits variation from model output, using string contains to match + assertTrue( + 
result1.contains("{\"_index\":\"test_neural_index\",\"_source\":{\"text\":\"hello world\"},\"_id\":\"0\",\"_score\":0.5671488") + ); + assertTrue(result1.contains("{\"_index\":\"test_neural_index\",\"_source\":{\"text\":\"a b\"},\"_id\":\"1\",\"_score\":0.2423683")); + + // if blank input, call onFailure and get exception + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"\"}}")); + + org.hamcrest.MatcherAssert + .assertThat( + exception.getMessage(), + allOf(containsString("[input] is null or empty, can not process it."), containsString("IllegalArgumentException")) + ); + + } + + public void testRAGToolWithNeuralQueryAndLLM() { + String agentId = createAgent(registerAgentWithNeuralQueryAndLLMRequestBody); + + // neural query to test match similar text, doc1 match with higher score + String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"use RAGTool to answer a\"}}"); + assertEquals(mockLLMResponseWithSource, result); + + // if blank input, call onFailure and get exception + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"\"}}")); + + org.hamcrest.MatcherAssert + .assertThat( + exception.getMessage(), + allOf(containsString("[input] is null or empty, can not process it."), containsString("IllegalArgumentException")) + ); + + } + + public void testRAGToolWithNeuralSparseQuery() { + String agentId = createAgent(registerAgentWithNeuralSparseQueryRequestBody); + + // neural sparse query to test match extract same text, doc1 match with high score + String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}"); + assertEquals( + "The agent execute response not equal with expected.", + "{\"_index\":\"test_neural_sparse_index\",\"_source\":{\"text\":\"a b\"},\"_id\":\"1\",\"_score\":1.2068367}\n", + result + ); + + // neural sparse query to test match extract non-existed text, no match + String result2 = executeAgent(agentId, "{\"parameters\": {\"question\": \"c\"}}"); + assertEquals("The agent execute response not equal with expected.", "Can not get any match from search result.", result2); + + // if blank input, call onFailure and get exception + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"\"}}")); + + org.hamcrest.MatcherAssert + .assertThat( + exception.getMessage(), + allOf(containsString("[input] is null or empty, can not process it."), containsString("IllegalArgumentException")) + ); + } + + public void testRAGToolWithNeuralSparseQueryAndLLM() { + String agentId = createAgent(registerAgentWithNeuralSparseQueryAndLLMRequestBody); + + // neural sparse query to test match extract same text, doc1 match with high score + String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"use RAGTool to answer a\"}}"); + assertEquals(mockLLMResponseWithSource, result); + + // if blank input, call onFailure and get exception + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"\"}}")); + + org.hamcrest.MatcherAssert + .assertThat( + exception.getMessage(), + allOf(containsString("[input] is null or empty, can not process it."), containsString("IllegalArgumentException")) + ); + } + + public void testRAGToolWithNeuralSparseQuery_withIllegalSourceField_thenGetEmptySource() { + String agentId = 
createAgent(registerAgentWithNeuralSparseQueryRequestBody.replace("text", "text2")); + String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}"); + assertEquals( + "The agent execute response not equal with expected.", + "{\"_index\":\"test_neural_sparse_index\",\"_source\":{},\"_id\":\"1\",\"_score\":1.2068367}\n", + result + ); + } + + public void testRAGToolWithNeuralSparseQueryAndLLM_withIllegalSourceField_thenGetEmptySource() { + String agentId = createAgent(registerAgentWithNeuralSparseQueryAndLLMRequestBody.replace("text", "text2")); + String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}"); + assertEquals(mockLLMResponseWithoutSource, result); + } + + public void testRAGToolWithNeuralQuery_withIllegalSourceField_thenGetEmptySource() { + String agentId = createAgent(registerAgentWithNeuralQueryRequestBody.replace("text", "text2")); + String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}"); + + // To allow digits variation from model output, using string contains to match + assertTrue(result.contains("{\"_index\":\"test_neural_index\",\"_source\":{},\"_id\":\"0\",\"_score\":0.70493")); + assertTrue(result.contains("{\"_index\":\"test_neural_index\",\"_source\":{},\"_id\":\"1\",\"_score\":0.26505")); + + } + + public void testRAGToolWithNeuralQueryAndLLM_withIllegalSourceField_thenGetEmptySource() { + String agentId = createAgent(registerAgentWithNeuralQueryAndLLMRequestBody.replace("text", "text2")); + String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}"); + assertEquals(mockLLMResponseWithoutSource, result); + } + + public void testRAGToolWithNeuralSparseQuery_withIllegalEmbeddingField_thenThrowException() { + String agentId = createAgent(registerAgentWithNeuralSparseQueryRequestBody.replace("\"embedding\"", "\"embedding2\"")); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}")); + + org.hamcrest.MatcherAssert + .assertThat( + exception.getMessage(), + allOf( + containsString("[neural_sparse] query only works on [rank_features] fields"), + containsString("IllegalArgumentException") + ) + ); + } + + public void testRAGToolWithNeuralSparseQueryAndLLM_withIllegalEmbeddingField_thenThrowException() { + String agentId = createAgent(registerAgentWithNeuralSparseQueryAndLLMRequestBody.replace("\"embedding\"", "\"embedding2\"")); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}")); + + org.hamcrest.MatcherAssert + .assertThat( + exception.getMessage(), + allOf( + containsString("[neural_sparse] query only works on [rank_features] fields"), + containsString("IllegalArgumentException") + ) + ); + } + + public void testRAGToolWithNeuralQuery_withIllegalEmbeddingField_thenThrowException() { + String agentId = createAgent(registerAgentWithNeuralQueryRequestBody.replace("\"embedding\"", "\"embedding2\"")); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}")); + + org.hamcrest.MatcherAssert + .assertThat( + exception.getMessage(), + allOf(containsString("Field 'embedding2' is not knn_vector type."), containsString("IllegalArgumentException")) + ); + } + + public void testRAGToolWithNeuralQueryAndLLM_withIllegalEmbeddingField_thenThrowException() { + String agentId = createAgent(registerAgentWithNeuralQueryAndLLMRequestBody.replace("\"embedding\"", 
"\"embedding2\"")); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}")); + + org.hamcrest.MatcherAssert + .assertThat( + exception.getMessage(), + allOf(containsString("Field 'embedding2' is not knn_vector type."), containsString("IllegalArgumentException")) + ); + } + + public void testRAGToolWithNeuralSparseQuery_withIllegalIndexField_thenThrowException() { + String agentId = createAgent(registerAgentWithNeuralSparseQueryRequestBody.replace(TEST_NEURAL_SPARSE_INDEX_NAME, "test_index2")); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}")); + + org.hamcrest.MatcherAssert + .assertThat( + exception.getMessage(), + allOf(containsString("no such index [test_index2]"), containsString("IndexNotFoundException")) + ); + } + + public void testRAGToolWithNeuralSparseQueryAndLLM_withIllegalIndexField_thenThrowException() { + String agentId = createAgent( + registerAgentWithNeuralSparseQueryAndLLMRequestBody.replace(TEST_NEURAL_SPARSE_INDEX_NAME, "test_index2") + ); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}")); + + org.hamcrest.MatcherAssert + .assertThat( + exception.getMessage(), + allOf(containsString("no such index [test_index2]"), containsString("IndexNotFoundException")) + ); + } + + public void testRAGToolWithNeuralQuery_withIllegalIndexField_thenThrowException() { + String agentId = createAgent(registerAgentWithNeuralQueryRequestBody.replace(TEST_NEURAL_INDEX_NAME, "test_index2")); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}")); + + org.hamcrest.MatcherAssert + .assertThat( + exception.getMessage(), + allOf(containsString("no such index [test_index2]"), containsString("IndexNotFoundException")) + ); + } + + public void testRAGToolWithNeuralQueryAndLLM_withIllegalIndexField_thenThrowException() { + String agentId = createAgent(registerAgentWithNeuralQueryAndLLMRequestBody.replace(TEST_NEURAL_INDEX_NAME, "test_index2")); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}")); + + org.hamcrest.MatcherAssert + .assertThat( + exception.getMessage(), + allOf(containsString("no such index [test_index2]"), containsString("IndexNotFoundException")) + ); + } + + public void testRAGToolWithNeuralSparseQuery_withIllegalModelIdField_thenThrowException() { + String agentId = createAgent(registerAgentWithNeuralSparseQueryRequestBody.replace(sparseEncodingModelId, "test_model_id")); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}")); + + org.hamcrest.MatcherAssert + .assertThat(exception.getMessage(), allOf(containsString("Failed to find model"), containsString("OpenSearchStatusException"))); + } + + public void testRAGToolWithNeuralSparseQueryAndLLM_withIllegalModelIdField_thenThrowException() { + String agentId = createAgent(registerAgentWithNeuralSparseQueryAndLLMRequestBody.replace(sparseEncodingModelId, "test_model_id")); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}")); + + org.hamcrest.MatcherAssert + .assertThat(exception.getMessage(), allOf(containsString("Failed to find model"), 
containsString("OpenSearchStatusException"))); + } + + public void testRAGToolWithNeuralQuery_withIllegalModelIdField_thenThrowException() { + String agentId = createAgent(registerAgentWithNeuralQueryRequestBody.replace(textEmbeddingModelId, "test_model_id")); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}")); + + org.hamcrest.MatcherAssert + .assertThat(exception.getMessage(), allOf(containsString("Failed to find model"), containsString("OpenSearchStatusException"))); + } + + public void testRAGToolWithNeuralQueryAndLLM_withIllegalModelIdField_thenThrowException() { + String agentId = createAgent(registerAgentWithNeuralQueryAndLLMRequestBody.replace(textEmbeddingModelId, "test_model_id")); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}")); + + org.hamcrest.MatcherAssert + .assertThat(exception.getMessage(), allOf(containsString("Failed to find model"), containsString("OpenSearchStatusException"))); + } + + @Override + List promptHandlers() { + PromptHandler RAGToolHandler = new PromptHandler() { + @Override + String response(String prompt) { + String expectPromptForNeuralSparseQuery = "\n" + + "\nHuman:You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say don't know. \n" + + "\n" + + " Context:\n" + + "\"_id: 1\\n_source: {\\\"text\\\":\\\"a b\\\"}\\n\"\n" + + "\n" + + "Human:use RAGTool to answer a\n" + + "\n" + + "Assistant:"; + String expectPromptForNeuralQuery = "\n" + + "\n" + + "Human:You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say don't know. 
\n" + + "\n" + + " Context:\n" + + "\"_id: 1\\n_source: {\\\"text\\\":\\\"a b\\\"}\\n_id: 0\\n_source: {\\\"text\\\":\\\"hello world\\\"}\\n\"\n" + + "\n" + + "Human:use RAGTool to answer a\n" + + "\n" + + "Assistant:"; + if (prompt.equals(expectPromptForNeuralSparseQuery) || prompt.equals(expectPromptForNeuralQuery)) { + return mockLLMResponseWithSource; + } else { + return mockLLMResponseWithoutSource; + } + } + + @Override + boolean apply(String prompt) { + return true; + } + }; + return List.of(RAGToolHandler); + } + + @Override + String toolType() { + return RAGTool.TYPE; + } +} diff --git a/src/test/java/org/opensearch/integTest/SearchAlertsToolIT.java b/src/test/java/org/opensearch/integTest/SearchAlertsToolIT.java new file mode 100644 index 00000000..95872d22 --- /dev/null +++ b/src/test/java/org/opensearch/integTest/SearchAlertsToolIT.java @@ -0,0 +1,166 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import java.nio.file.Files; +import java.nio.file.Path; + +import org.junit.After; +import org.junit.Before; +import org.junit.jupiter.api.BeforeEach; +import org.opensearch.agent.tools.utils.ToolConstants; + +import com.google.gson.Gson; +import com.google.gson.JsonObject; + +import lombok.SneakyThrows; + +public class SearchAlertsToolIT extends BaseAgentToolsIT { + private String registerAgentRequestBody; + private String alertsIndexMappings; + private String alertingConfigIndexMappings; + private String sampleAlert; + private static final String monitorId = "foo-id"; + private static final String monitorName = "foo-name"; + private static final String registerAgentFilepath = + "org/opensearch/agent/tools/alerting/register_flow_agent_of_search_alerts_tool_request_body.json"; + private static final String alertsIndexMappingsFilepath = "org/opensearch/agent/tools/alerting/alert_index_mappings.json"; + private static final String alertingConfigIndexMappingsFilepath = + "org/opensearch/agent/tools/alerting/alerting_config_index_mappings.json"; + private static final String sampleAlertFilepath = "org/opensearch/agent/tools/alerting/sample_alert.json"; + + @Before + @SneakyThrows + public void setUp() { + super.setUp(); + registerAgentRequestBody = Files.readString(Path.of(this.getClass().getClassLoader().getResource(registerAgentFilepath).toURI())); + alertsIndexMappings = Files.readString(Path.of(this.getClass().getClassLoader().getResource(alertsIndexMappingsFilepath).toURI())); + alertingConfigIndexMappings = Files + .readString(Path.of(this.getClass().getClassLoader().getResource(alertingConfigIndexMappingsFilepath).toURI())); + sampleAlert = Files.readString(Path.of(this.getClass().getClassLoader().getResource(sampleAlertFilepath).toURI())); + } + + @BeforeEach + @SneakyThrows + public void prepareTest() { + deleteSystemIndices(); + } + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + deleteExternalIndices(); + deleteSystemIndices(); + } + + @SneakyThrows + public void testSearchAlertsToolInFlowAgent_withNoSystemIndex() { + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\"parameters\":{}}"; + String result = executeAgent(agentId, agentInput); + assertEquals("Alerts=[]TotalAlerts=0", result); + } + + @SneakyThrows + public void testSearchAlertsToolInFlowAgent_withSystemIndex() { + setupAlertingSystemIndices(); + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\"parameters\":{}}"; + String result = 
executeAgent(agentId, agentInput); + assertEquals("Alerts=[]TotalAlerts=0", result); + } + + @SneakyThrows + public void testSearchAlertsToolInFlowAgent_singleAlert_noFilter() { + setupAlertingSystemIndices(); + ingestSampleAlert(monitorId, "1"); + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\"parameters\":{}}"; + String result = executeAgent(agentId, agentInput); + assertTrue(result.contains("TotalAlerts=1")); + } + + @SneakyThrows + public void testSearchAlertsToolInFlowAgent_singleAlert_filter_match() { + setupAlertingSystemIndices(); + ingestSampleAlert(monitorId, "1"); + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\"parameters\":{\"monitorId\": \"" + monitorId + "\"}}"; + String result = executeAgent(agentId, agentInput); + assertTrue(result.contains("TotalAlerts=1")); + } + + @SneakyThrows + public void testSearchAlertsToolInFlowAgent_singleAlert_filter_noMatch() { + setupAlertingSystemIndices(); + ingestSampleAlert(monitorId, "1"); + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\"parameters\":{\"monitorId\": \"" + monitorId + "foo" + "\"}}"; + String result = executeAgent(agentId, agentInput); + assertTrue(result.contains("TotalAlerts=0")); + } + + @SneakyThrows + public void testSearchAlertsToolInFlowAgent_multipleAlerts_noFilter() { + setupAlertingSystemIndices(); + ingestSampleAlert(monitorId, "1"); + ingestSampleAlert(monitorId + "foo", "2"); + ingestSampleAlert(monitorId + "bar", "3"); + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\"parameters\":{}}"; + String result = executeAgent(agentId, agentInput); + assertTrue(result.contains("TotalAlerts=3")); + } + + @SneakyThrows + public void testSearchAlertsToolInFlowAgent_multipleAlerts_filter() { + setupAlertingSystemIndices(); + ingestSampleAlert(monitorId, "1"); + ingestSampleAlert(monitorId + "foo", "2"); + ingestSampleAlert(monitorId + "bar", "3"); + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\"parameters\":{\"monitorId\": \"" + monitorId + "\"}}"; + String result = executeAgent(agentId, agentInput); + assertTrue(result.contains("TotalAlerts=1")); + } + + @SneakyThrows + public void testSearchAlertsToolInFlowAgent_multipleAlerts_complexParams() { + setupAlertingSystemIndices(); + String monitorId2 = monitorId + "2"; + String monitorId3 = monitorId + "3"; + ingestSampleAlert(monitorId, "1"); + ingestSampleAlert(monitorId2, "2"); + ingestSampleAlert(monitorId3, "3"); + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\"parameters\":{\"monitorIds\": " + + "[ \"" + + monitorId + + "\", \"" + + monitorId2 + + "\", \"" + + monitorId3 + + "\" ], " + + "\"sortOrder\": \"asc\", \"sortString\": \"monitor_name.keyword\", \"size\": 3, \"startIndex\": 0, \"severityLevel\": \"ALL\", \"alertState\": \"ALL\" } }"; + + String result = executeAgent(agentId, agentInput); + assertTrue(result.contains("TotalAlerts=3")); + } + + @SneakyThrows + private void setupAlertingSystemIndices() { + createIndexWithConfiguration(ToolConstants.ALERTING_ALERTS_INDEX, alertsIndexMappings); + createIndexWithConfiguration(ToolConstants.ALERTING_CONFIG_INDEX, alertingConfigIndexMappings); + } + + private void ingestSampleAlert(String monitorId, String docId) { + JsonObject sampleAlertJson = new Gson().fromJson(sampleAlert, JsonObject.class); + sampleAlertJson.addProperty("monitor_id", monitorId); + addDocToIndex(ToolConstants.ALERTING_ALERTS_INDEX, docId, 
sampleAlertJson.toString()); + } + +} diff --git a/src/test/java/org/opensearch/integTest/SearchAnomalyDetectorsToolIT.java b/src/test/java/org/opensearch/integTest/SearchAnomalyDetectorsToolIT.java new file mode 100644 index 00000000..eb0c529d --- /dev/null +++ b/src/test/java/org/opensearch/integTest/SearchAnomalyDetectorsToolIT.java @@ -0,0 +1,275 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Locale; + +import org.junit.After; +import org.junit.Before; +import org.junit.jupiter.api.MethodOrderer.OrderAnnotation; +import org.junit.jupiter.api.Order; +import org.junit.jupiter.api.TestMethodOrder; + +import com.google.gson.Gson; +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; + +import lombok.SneakyThrows; + +@TestMethodOrder(OrderAnnotation.class) +public class SearchAnomalyDetectorsToolIT extends BaseAgentToolsIT { + private String registerAgentRequestBody; + private String sampleDetector; + private String sampleIndexMappings; + private static final String detectorName = "foo-name"; + private static final String registerAgentFilepath = + "org/opensearch/agent/tools/anomaly-detection/register_flow_agent_of_search_anomaly_detectors_tool_request_body.json"; + private static final String sampleDetectorFilepath = "org/opensearch/agent/tools/anomaly-detection/sample_detector.json"; + private static final String sampleIndexMappingsFilepath = "org/opensearch/agent/tools/anomaly-detection/sample_index_mappings.json"; + + @Before + @SneakyThrows + public void setUp() { + super.setUp(); + registerAgentRequestBody = Files.readString(Path.of(this.getClass().getClassLoader().getResource(registerAgentFilepath).toURI())); + sampleDetector = Files.readString(Path.of(this.getClass().getClassLoader().getResource(sampleDetectorFilepath).toURI())); + sampleIndexMappings = Files.readString(Path.of(this.getClass().getClassLoader().getResource(sampleIndexMappingsFilepath).toURI())); + } + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + deleteExternalIndices(); + } + + @SneakyThrows + @Order(1) + public void testSearchAnomalyDetectorsToolInFlowAgent_withNoSystemIndex() { + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\"parameters\":{\"detectorName\": \"" + detectorName + "\"}}"; + String result = executeAgent(agentId, agentInput); + assertEquals("AnomalyDetectors=[]TotalAnomalyDetectors=0", result); + } + + @SneakyThrows + @Order(2) + public void testSearchAnomalyDetectorsToolInFlowAgent_detectorNameParam() { + String detectorId = null; + try { + setupTestDetectionIndex("test-index"); + detectorId = ingestSampleDetector(detectorName, "test-index"); + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\"parameters\":{\"detectorName\": \"" + detectorName + "foo" + "\"}}"; + String result = executeAgent(agentId, agentInput); + assertEquals("AnomalyDetectors=[]TotalAnomalyDetectors=0", result); + + String agentInput2 = "{\"parameters\":{\"detectorName\": \"" + detectorName + "\"}}"; + String result2 = executeAgent(agentId, agentInput2); + assertTrue(result2.contains(String.format(Locale.ROOT, "id=%s", detectorId))); + assertTrue(result2.contains(String.format(Locale.ROOT, "name=%s", detectorName))); + assertTrue(result2.contains(String.format(Locale.ROOT, "TotalAnomalyDetectors=%d", 1))); + } finally { + if (detectorId != 
null) { + deleteDetector(detectorId); + } + } + } + + @SneakyThrows + @Order(3) + public void testSearchAnomalyDetectorsToolInFlowAgent_detectorNamePatternParam() { + String detectorId = null; + try { + setupTestDetectionIndex("test-index"); + detectorId = ingestSampleDetector(detectorName, "test-index"); + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\"parameters\":{\"detectorNamePattern\": \"" + detectorName + "foo" + "\"}}"; + String result = executeAgent(agentId, agentInput); + assertEquals("AnomalyDetectors=[]TotalAnomalyDetectors=0", result); + + String agentInput2 = "{\"parameters\":{\"detectorNamePattern\": \"" + detectorName + "*" + "\"}}"; + String result2 = executeAgent(agentId, agentInput2); + assertTrue(result2.contains(String.format(Locale.ROOT, "id=%s", detectorId))); + assertTrue(result2.contains(String.format(Locale.ROOT, "name=%s", detectorName))); + assertTrue(result2.contains(String.format(Locale.ROOT, "TotalAnomalyDetectors=%d", 1))); + } finally { + if (detectorId != null) { + deleteDetector(detectorId); + } + } + + } + + @SneakyThrows + @Order(4) + public void testSearchAnomalyDetectorsToolInFlowAgent_indicesParam() { + String detectorId = null; + try { + setupTestDetectionIndex("test-index"); + detectorId = ingestSampleDetector(detectorName, "test-index"); + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\"parameters\":{\"indices\": \"test-index-foo\"}}"; + String result = executeAgent(agentId, agentInput); + assertEquals("AnomalyDetectors=[]TotalAnomalyDetectors=0", result); + + String agentInput2 = "{\"parameters\":{\"indices\": \"test-index\"}}"; + String result2 = executeAgent(agentId, agentInput2); + assertTrue(result2.contains(String.format(Locale.ROOT, "TotalAnomalyDetectors=%d", 1))); + } finally { + if (detectorId != null) { + deleteDetector(detectorId); + } + } + + } + + @SneakyThrows + @Order(5) + public void testSearchAnomalyDetectorsToolInFlowAgent_highCardinalityParam() { + String detectorId = null; + try { + setupTestDetectionIndex("test-index"); + detectorId = ingestSampleDetector(detectorName, "test-index"); + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\"parameters\":{\"highCardinality\": \"true\"}}"; + String result = executeAgent(agentId, agentInput); + assertEquals("AnomalyDetectors=[]TotalAnomalyDetectors=0", result); + + String agentInput2 = "{\"parameters\":{\"highCardinality\": \"false\"}}"; + String result2 = executeAgent(agentId, agentInput2); + assertTrue(result2.contains(String.format(Locale.ROOT, "id=%s", detectorId))); + assertTrue(result2.contains(String.format(Locale.ROOT, "name=%s", detectorName))); + assertTrue(result2.contains(String.format(Locale.ROOT, "TotalAnomalyDetectors=%d", 1))); + } finally { + if (detectorId != null) { + deleteDetector(detectorId); + } + } + + } + + @SneakyThrows + @Order(6) + public void testSearchAnomalyDetectorsToolInFlowAgent_detectorStateParams() { + String detectorIdRunning = null; + String detectorIdDisabled1 = null; + String detectorIdDisabled2 = null; + try { + // TODO: update test scenarios + setupTestDetectionIndex("test-index"); + detectorIdRunning = ingestSampleDetector(detectorName + "-running", "test-index"); + detectorIdDisabled1 = ingestSampleDetector(detectorName + "-disabled-1", "test-index"); + detectorIdDisabled2 = ingestSampleDetector(detectorName + "-disabled-2", "test-index"); + startDetector(detectorIdRunning); + Thread.sleep(5000); + + String agentId = 
createAgent(registerAgentRequestBody); + String agentInput = "{\"parameters\":{\"running\": \"true\"}}"; + String result = executeAgent(agentId, agentInput); + assertTrue(result.contains(String.format(Locale.ROOT, "TotalAnomalyDetectors=%d", 1))); + assertTrue(result.contains(detectorIdRunning)); + + String agentInput2 = "{\"parameters\":{\"running\": \"false\"}}"; + String result2 = executeAgent(agentId, agentInput2); + assertTrue(result2.contains(String.format(Locale.ROOT, "TotalAnomalyDetectors=%d", 2))); + assertTrue(result2.contains(detectorIdDisabled1)); + assertTrue(result2.contains(detectorIdDisabled2)); + + String agentInput3 = "{\"parameters\":{\"failed\": \"true\"}}"; + String result3 = executeAgent(agentId, agentInput3); + assertTrue(result3.contains(String.format(Locale.ROOT, "TotalAnomalyDetectors=%d", 0))); + + String agentInput4 = "{\"parameters\":{\"failed\": \"false\"}}"; + String result4 = executeAgent(agentId, agentInput4); + assertTrue(result4.contains(String.format(Locale.ROOT, "TotalAnomalyDetectors=%d", 3))); + assertTrue(result4.contains(detectorIdRunning)); + assertTrue(result4.contains(detectorIdDisabled1)); + assertTrue(result4.contains(detectorIdDisabled2)); + + String agentInput5 = "{\"parameters\":{\"running\": \"true\", \"failed\": \"true\"}}"; + String result5 = executeAgent(agentId, agentInput5); + assertTrue(result5.contains(String.format(Locale.ROOT, "TotalAnomalyDetectors=%d", 1))); + assertTrue(result5.contains(detectorIdRunning)); + + String agentInput6 = "{\"parameters\":{\"running\": \"true\", \"failed\": \"false\"}}"; + String result6 = executeAgent(agentId, agentInput6); + assertTrue(result6.contains(String.format(Locale.ROOT, "TotalAnomalyDetectors=%d", 1))); + assertTrue(result6.contains(detectorIdRunning)); + + String agentInput7 = "{\"parameters\":{\"running\": \"false\", \"failed\": \"true\"}}"; + String result7 = executeAgent(agentId, agentInput7); + assertTrue(result7.contains(String.format(Locale.ROOT, "TotalAnomalyDetectors=%d", 2))); + assertTrue(result7.contains(detectorIdDisabled1)); + assertTrue(result7.contains(detectorIdDisabled2)); + + String agentInput8 = "{\"parameters\":{\"running\": \"false\", \"failed\": \"false\"}}"; + String result8 = executeAgent(agentId, agentInput8); + assertTrue(result8.contains(String.format(Locale.ROOT, "TotalAnomalyDetectors=%d", 2))); + assertTrue(result8.contains(detectorIdDisabled1)); + assertTrue(result8.contains(detectorIdDisabled2)); + } finally { + if (detectorIdRunning != null) { + stopDetector(detectorIdRunning); + Thread.sleep(5000); + deleteDetector(detectorIdRunning); + } + if (detectorIdDisabled1 != null) { + deleteDetector(detectorIdDisabled1); + } + if (detectorIdDisabled2 != null) { + deleteDetector(detectorIdDisabled2); + } + } + + } + + @SneakyThrows + @Order(7) + public void testSearchAnomalyDetectorsToolInFlowAgent_complexParams() { + String detectorId = null; + String detectorIdFoo = null; + try { + setupTestDetectionIndex("test-index"); + detectorId = ingestSampleDetector(detectorName, "test-index"); + detectorIdFoo = ingestSampleDetector(detectorName + "foo", "test-index"); + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\"parameters\":{\"detectorName\": \"" + + detectorName + + "\", \"highCardinality\": false, \"sortOrder\": \"asc\", \"sortString\": \"name.keyword\", \"size\": 10, \"startIndex\": 0 }}"; + String result = executeAgent(agentId, agentInput); + assertTrue(result.contains(String.format(Locale.ROOT, "id=%s", detectorId))); + 
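+            // the combined name, cardinality, sort, and paging parameters should leave exactly one matching detector, asserted below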
assertTrue(result.contains(String.format(Locale.ROOT, "name=%s", detectorName))); + assertTrue(result.contains(String.format(Locale.ROOT, "TotalAnomalyDetectors=%d", 1))); + } finally { + if (detectorId != null) { + deleteDetector(detectorId); + } + if (detectorIdFoo != null) { + deleteDetector(detectorIdFoo); + } + } + } + + @SneakyThrows + private void setupTestDetectionIndex(String indexName) { + createIndexWithConfiguration(indexName, sampleIndexMappings); + addDocToIndex(indexName, "foo-id", List.of("timestamp", "value"), List.of(1234, 1)); + } + + private String ingestSampleDetector(String detectorName, String detectionIndex) { + JsonObject sampleDetectorJson = new Gson().fromJson(sampleDetector, JsonObject.class); + JsonArray arr = new JsonArray(1); + arr.add(detectionIndex); + sampleDetectorJson.addProperty("name", detectorName); + sampleDetectorJson.remove("indices"); + sampleDetectorJson.add("indices", arr); + return indexDetector(sampleDetectorJson.toString()); + } +} diff --git a/src/test/java/org/opensearch/integTest/SearchAnomalyResultsToolIT.java b/src/test/java/org/opensearch/integTest/SearchAnomalyResultsToolIT.java new file mode 100644 index 00000000..46234ea6 --- /dev/null +++ b/src/test/java/org/opensearch/integTest/SearchAnomalyResultsToolIT.java @@ -0,0 +1,138 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Locale; + +import org.junit.After; +import org.junit.Before; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.MethodOrderer.OrderAnnotation; +import org.junit.jupiter.api.Order; +import org.junit.jupiter.api.TestMethodOrder; +import org.opensearch.agent.tools.utils.ToolConstants; + +import com.google.gson.Gson; +import com.google.gson.JsonObject; + +import lombok.SneakyThrows; + +@TestMethodOrder(OrderAnnotation.class) +public class SearchAnomalyResultsToolIT extends BaseAgentToolsIT { + private String registerAgentRequestBody; + private String resultsIndexMappings; + private String sampleResult; + private static final String detectorId = "foo-id"; + private static final double anomalyGrade = 0.5; + private static final double confidence = 0.6; + private static final String resultsSystemIndexName = ".opendistro-anomaly-results-1"; + private static final String registerAgentFilepath = + "org/opensearch/agent/tools/anomaly-detection/register_flow_agent_of_search_anomaly_results_tool_request_body.json"; + private static final String resultsIndexMappingsFilepath = "org/opensearch/agent/tools/anomaly-detection/results_index_mappings.json"; + private static final String sampleResultFilepath = "org/opensearch/agent/tools/anomaly-detection/sample_result.json"; + + @Before + @SneakyThrows + public void setUp() { + deleteExternalIndices(); + deleteSystemIndices(); + super.setUp(); + registerAgentRequestBody = Files.readString(Path.of(this.getClass().getClassLoader().getResource(registerAgentFilepath).toURI())); + resultsIndexMappings = Files + .readString(Path.of(this.getClass().getClassLoader().getResource(resultsIndexMappingsFilepath).toURI())); + sampleResult = Files.readString(Path.of(this.getClass().getClassLoader().getResource(sampleResultFilepath).toURI())); + } + + @BeforeEach + @SneakyThrows + public void prepareTest() { + deleteSystemIndices(); + } + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + deleteExternalIndices(); + deleteSystemIndices(); + } + + 
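+    // Ordered cases: apart from the no-system-index test, each case below seeds the AD results system index via setupADSystemIndices()/ingestSampleResult() before asserting on the tool's formatted output.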
@SneakyThrows + @Order(1) + public void testSearchAnomalyResultsToolInFlowAgent_withNoSystemIndex() { + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\"parameters\":{\"detectorId\": \"" + detectorId + "\"}}"; + String result = executeAgent(agentId, agentInput); + assertEquals("AnomalyResults=[]TotalAnomalyResults=0", result); + } + + @SneakyThrows + @Order(2) + public void testSearchAnomalyResultsToolInFlowAgent_noMatching() { + setupADSystemIndices(); + ingestSampleResult(detectorId, 0.5, 0.5, "1"); + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\"parameters\":{\"detectorId\": \"" + detectorId + "foo" + "\"}}"; + String result = executeAgent(agentId, agentInput); + assertEquals("AnomalyResults=[]TotalAnomalyResults=0", result); + } + + @SneakyThrows + @Order(3) + public void testSearchAnomalyResultsToolInFlowAgent_matching() { + setupADSystemIndices(); + ingestSampleResult(detectorId, anomalyGrade, confidence, "1"); + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\"parameters\":{\"detectorId\": \"" + detectorId + "\"}}"; + String result = executeAgent(agentId, agentInput); + assertEquals( + String + .format( + Locale.ROOT, + "AnomalyResults=[{detectorId=%s,grade=%2.1f,confidence=%2.1f}]TotalAnomalyResults=%d", + detectorId, + anomalyGrade, + confidence, + 1 + ), + result + ); + } + + @SneakyThrows + @Order(4) + public void testSearchAnomalyResultsToolInFlowAgent_complexParams() { + setupADSystemIndices(); + ingestSampleResult(detectorId, anomalyGrade, confidence, "1"); + ingestSampleResult(detectorId + "foo", anomalyGrade, confidence, "2"); + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\"parameters\":{\"detectorId\": \"" + + detectorId + + "\"," + + "\"realTime\": true, \"anomalyGradeThreshold\": 0, \"sortOrder\": \"asc\"," + + "\"sortString\": \"data_start_time\", \"size\": 10, \"startIndex\": 0 }}"; + String result = executeAgent(agentId, agentInput); + assertTrue( + String.format(Locale.ROOT, "total anomaly results is not 1, result: %s", result), + result.contains(String.format(Locale.ROOT, "TotalAnomalyResults=%d", 1)) + ); + } + + @SneakyThrows + private void setupADSystemIndices() { + createIndexWithConfiguration(ToolConstants.AD_RESULTS_INDEX, resultsIndexMappings); + } + + private void ingestSampleResult(String detectorId, double anomalyGrade, double anomalyConfidence, String docId) { + JsonObject sampleResultJson = new Gson().fromJson(sampleResult, JsonObject.class); + sampleResultJson.addProperty("detector_id", detectorId); + sampleResultJson.addProperty("anomaly_grade", anomalyGrade); + // write the per-call confidence rather than the class-level constant + sampleResultJson.addProperty("confidence", anomalyConfidence); + addDocToIndex(ToolConstants.AD_RESULTS_INDEX, docId, sampleResultJson.toString()); + } +} diff --git a/src/test/java/org/opensearch/integTest/SearchMonitorsToolIT.java b/src/test/java/org/opensearch/integTest/SearchMonitorsToolIT.java new file mode 100644 index 00000000..b0ee3503 --- /dev/null +++ b/src/test/java/org/opensearch/integTest/SearchMonitorsToolIT.java @@ -0,0 +1,153 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Locale; + +import org.junit.After; +import org.junit.Before; +import org.junit.jupiter.api.MethodOrderer.OrderAnnotation; +import org.junit.jupiter.api.Order; 
+import org.junit.jupiter.api.TestMethodOrder; + +import com.google.gson.Gson; +import com.google.gson.JsonObject; + +import lombok.SneakyThrows; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@TestMethodOrder(OrderAnnotation.class) +public class SearchMonitorsToolIT extends BaseAgentToolsIT { + private String registerAgentRequestBody; + private String sampleMonitor; + private static final String monitorName = "foo-name"; + private static final String monitorName2 = "bar-name"; + private static final String registerAgentFilepath = + "org/opensearch/agent/tools/alerting/register_flow_agent_of_search_monitors_tool_request_body.json"; + private static final String sampleMonitorFilepath = "org/opensearch/agent/tools/alerting/sample_monitor.json"; + + @Before + @SneakyThrows + public void setUp() { + super.setUp(); + registerAgentRequestBody = Files.readString(Path.of(this.getClass().getClassLoader().getResource(registerAgentFilepath).toURI())); + sampleMonitor = Files.readString(Path.of(this.getClass().getClassLoader().getResource(sampleMonitorFilepath).toURI())); + } + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + deleteExternalIndices(); + } + + @SneakyThrows + @Order(1) + public void testSearchMonitorsToolInFlowAgent_withNoSystemIndex() { + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\"parameters\":{\"monitorName\": \"" + monitorName + "\"}}"; + String result = executeAgent(agentId, agentInput); + assertEquals("Monitors=[]TotalMonitors=0", result); + } + + @SneakyThrows + @Order(2) + public void testSearchMonitorsToolInFlowAgent_searchById() { + String monitorId = ingestSampleMonitor(monitorName, true); + + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\"parameters\":{\"monitorId\": \"" + monitorId + "\"}}"; + String result = executeAgent(agentId, agentInput); + assertTrue(result.contains(String.format(Locale.ROOT, "name=%s", monitorName))); + assertTrue(result.contains("TotalMonitors=1")); + deleteMonitor(monitorId); + } + + @SneakyThrows + @Order(3) + public void testSearchMonitorsToolInFlowAgent_singleMonitor_noFilter() { + String monitorId = ingestSampleMonitor(monitorName, true); + + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\"parameters\":{}}"; + String result = executeAgent(agentId, agentInput); + assertTrue(result.contains(String.format(Locale.ROOT, "name=%s", monitorName))); + assertTrue(result.contains("TotalMonitors=1")); + deleteMonitor(monitorId); + } + + @SneakyThrows + @Order(4) + public void testSearchMonitorsToolInFlowAgent_singleMonitor_filter() { + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\"parameters\":{\"monitorId\": \"" + "foo-id" + "\"}}"; + String result = executeAgent(agentId, agentInput); + assertTrue(result.contains("TotalMonitors=0")); + } + + @SneakyThrows + @Order(5) + public void testSearchMonitorsToolInFlowAgent_multipleMonitors_noFilter() { + String monitorId1 = ingestSampleMonitor(monitorName, true); + String monitorId2 = ingestSampleMonitor(monitorName2, false); + + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\"parameters\":{}}"; + String result = executeAgent(agentId, agentInput); + assertTrue(result.contains(String.format(Locale.ROOT, "name=%s", monitorName))); + assertTrue(result.contains(String.format(Locale.ROOT, "name=%s", monitorName2))); + assertTrue(result.contains("enabled=true")); + assertTrue(result.contains("enabled=false")); + 
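+        // with no filter supplied, the reported total should count both the enabled and the disabled monitor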
assertTrue(result.contains("TotalMonitors=2")); + deleteMonitor(monitorId1); + deleteMonitor(monitorId2); + } + + @SneakyThrows + @Order(6) + public void testSearchMonitorsToolInFlowAgent_multipleMonitors_filter() { + String monitorId1 = ingestSampleMonitor(monitorName, true); + String monitorId2 = ingestSampleMonitor(monitorName2, false); + + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\"parameters\":{\"monitorName\": \"" + monitorName + "\"}}"; + String result = executeAgent(agentId, agentInput); + assertTrue(result.contains(String.format(Locale.ROOT, "name=%s", monitorName))); + assertFalse(result.contains(String.format(Locale.ROOT, "name=%s", monitorName2))); + assertTrue(result.contains("enabled=true")); + assertTrue(result.contains("TotalMonitors=1")); + deleteMonitor(monitorId1); + deleteMonitor(monitorId2); + } + + @SneakyThrows + @Order(7) + public void testSearchMonitorsToolInFlowAgent_multipleMonitors_complexParams() { + String monitorId1 = ingestSampleMonitor(monitorName, true); + String monitorId2 = ingestSampleMonitor(monitorName2, false); + + String agentId = createAgent(registerAgentRequestBody); + String agentInput = "{\"parameters\":{\"monitorName\": \"" + + monitorName + + "\", \"enabled\": true, \"hasTriggers\": false, \"sortOrder\": \"asc\", \"sortString\": \"monitor.name.keyword\", \"size\": 10, \"startIndex\": 0 }}"; + String result = executeAgent(agentId, agentInput); + assertTrue(result.contains("TotalMonitors=1")); + deleteMonitor(monitorId1); + deleteMonitor(monitorId2); + } + + private String ingestSampleMonitor(String monitorName, boolean enabled) { + JsonObject sampleMonitorJson = new Gson().fromJson(sampleMonitor, JsonObject.class); + sampleMonitorJson.addProperty("name", monitorName); + sampleMonitorJson.addProperty("enabled", String.valueOf(enabled)); + return indexMonitor(sampleMonitorJson.toString()); + } +} diff --git a/src/test/java/org/opensearch/integTest/ToolIntegrationTest.java b/src/test/java/org/opensearch/integTest/ToolIntegrationTest.java new file mode 100644 index 00000000..315812b1 --- /dev/null +++ b/src/test/java/org/opensearch/integTest/ToolIntegrationTest.java @@ -0,0 +1,235 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import java.io.IOException; +import java.io.InputStream; +import java.util.List; +import java.util.Locale; +import java.util.UUID; +import java.util.concurrent.TimeUnit; + +import org.junit.After; +import org.junit.Before; +import org.opensearch.client.Request; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.Response; + +import com.google.gson.Gson; +import com.google.gson.JsonParser; +import com.sun.net.httpserver.HttpServer; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public abstract class ToolIntegrationTest extends BaseAgentToolsIT { + protected HttpServer server; + protected String modelId; + protected String agentId; + protected String modelGroupId; + protected String connectorId; + + private final Gson gson = new Gson(); + + abstract List promptHandlers(); + + abstract String toolType(); + + @Before + public void setupTestAgent() throws IOException, InterruptedException { + server = MockHttpServer.setupMockLLM(promptHandlers()); + server.start(); + clusterSettings(false); + connectorId = setUpConnectorWithRetry(5); + modelGroupId = setupModelGroup(); + modelId = setupLLMModel(connectorId, modelGroupId); + // wait for model to get deployed + 
TimeUnit.SECONDS.sleep(1); + agentId = setupConversationalAgent(modelId); + log.info("model_id: {}, agent_id: {}", modelId, agentId); + } + + @After + public void cleanUpClusterSetting() throws IOException { + clusterSettings(true); + } + + @After + public void stopMockLLM() { + server.stop(1); + } + + @After + public void deleteModel() { + deleteModel(modelId); + } + + private String setUpConnectorWithRetry(int maxRetryTimes) throws InterruptedException { + int retryTimes = 0; + String connectorId = null; + while (retryTimes < maxRetryTimes) { + try { + connectorId = setUpConnector(); + break; + } catch (Exception e) { + // Wait for ML encryption master key has been initialized + log.info("Failed to setup connector, retry times: {}", retryTimes); + retryTimes++; + TimeUnit.SECONDS.sleep(10); + } + } + return connectorId; + } + + private String setUpConnector() { + String url = String.format(Locale.ROOT, "http://127.0.0.1:%d/invoke", server.getAddress().getPort()); + return createConnector( + "{\n" + + " \"name\": \"BedRock test claude Connector\",\n" + + " \"description\": \"The connector to BedRock service for claude model\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"aws_sigv4\",\n" + + " \"parameters\": {\n" + + " \"region\": \"us-east-1\",\n" + + " \"service_name\": \"bedrock\",\n" + + " \"anthropic_version\": \"bedrock-2023-05-31\",\n" + + " \"endpoint\": \"bedrock.us-east-1.amazonaws.com\",\n" + + " \"auth\": \"Sig_V4\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"max_tokens_to_sample\": 8000,\n" + + " \"temperature\": 0.0001,\n" + + " \"response_filter\": \"$.completion\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"access_key\": \"\",\n" + + " \"secret_key\": \"\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"" + + url + + "\",\n" + + " \"headers\": {\n" + + " \"content-type\": \"application/json\",\n" + + " \"x-amz-content-sha256\": \"required\"\n" + + " },\n" + + " \"request_body\": \"{\\\"prompt\\\":\\\"${parameters.prompt}\\\", \\\"max_tokens_to_sample\\\":${parameters.max_tokens_to_sample}, \\\"temperature\\\":${parameters.temperature}, \\\"anthropic_version\\\":\\\"${parameters.anthropic_version}\\\" }\"\n" + + " }\n" + + " ]\n" + + "}" + ); + } + + private void clusterSettings(boolean clean) throws IOException { + if (!clean) { + updateClusterSettings("plugins.ml_commons.only_run_on_ml_node", false); + updateClusterSettings("plugins.ml_commons.memory_feature_enabled", true); + updateClusterSettings("plugins.ml_commons.trusted_connector_endpoints_regex", List.of("^.*$")); + } else { + updateClusterSettings("plugins.ml_commons.only_run_on_ml_node", null); + updateClusterSettings("plugins.ml_commons.memory_feature_enabled", null); + updateClusterSettings("plugins.ml_commons.trusted_connector_endpoints_regex", null); + } + } + + private String setupModelGroup() throws IOException { + Request request = new Request("POST", "/_plugins/_ml/model_groups/_register"); + request + .setJsonEntity( + "{\n" + + " \"name\": \"test_model_group_bedrock-" + + UUID.randomUUID() + + "\",\n" + + " \"description\": \"This is a public model group\"\n" + + "}" + ); + Response response = executeRequest(request); + + String resp = readResponse(response); + + return JsonParser.parseString(resp).getAsJsonObject().get("model_group_id").getAsString(); + } + + private String setupLLMModel(String connectorId, String modelGroupId) throws IOException { + Request request = new Request("POST", 
"/_plugins/_ml/models/_register?deploy=true"); + request + .setJsonEntity( + "{\n" + + " \"name\": \"Bedrock Claude V2 model\",\n" + + " \"function_name\": \"remote\",\n" + + " \"model_group_id\": \"" + + modelGroupId + + "\",\n" + + " \"description\": \"test model\",\n" + + " \"connector_id\": \"" + + connectorId + + "\"\n" + + "}" + ); + Response response = executeRequest(request); + + String resp = readResponse(response); + + return JsonParser.parseString(resp).getAsJsonObject().get("model_id").getAsString(); + } + + private String setupConversationalAgent(String modelId) throws IOException { + Request request = new Request("POST", "/_plugins/_ml/agents/_register"); + request + .setJsonEntity( + "{\n" + + " \"name\": \"integTest-agent\",\n" + + " \"type\": \"conversational\",\n" + + " \"description\": \"this is a test agent\",\n" + + " \"llm\": {\n" + + " \"model_id\": \"" + + modelId + + "\",\n" + + " \"parameters\": {\n" + + " \"max_iteration\": \"5\",\n" + + " \"stop_when_no_tool_found\": \"true\",\n" + + " \"response_filter\": \"$.completion\"\n" + + " }\n" + + " },\n" + + " \"tools\": [\n" + + " {\n" + + " \"type\": \"" + + toolType() + + "\",\n" + + " \"name\": \"" + + toolType() + + "\",\n" + + " \"include_output_in_agent_response\": true,\n" + + " \"description\": \"tool description\"\n" + + " }\n" + + " ],\n" + + " \"memory\": {\n" + + " \"type\": \"conversation_index\"\n" + + " }\n" + + "}" + ); + Response response = executeRequest(request); + + String resp = readResponse(response); + + return JsonParser.parseString(resp).getAsJsonObject().get("agent_id").getAsString(); + } + + public static Response executeRequest(Request request) throws IOException { + RequestOptions.Builder builder = RequestOptions.DEFAULT.toBuilder(); + builder.addHeader("Content-Type", "application/json"); + request.setOptions(builder); + return client().performRequest(request); + } + + public static String readResponse(Response response) throws IOException { + try (InputStream ins = response.getEntity().getContent()) { + return String.join("", org.opensearch.common.io.Streams.readAllLines(ins)); + } + } +} diff --git a/src/test/java/org/opensearch/integTest/VectorDBToolIT.java b/src/test/java/org/opensearch/integTest/VectorDBToolIT.java new file mode 100644 index 00000000..3f7fc77e --- /dev/null +++ b/src/test/java/org/opensearch/integTest/VectorDBToolIT.java @@ -0,0 +1,279 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.integTest; + +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertThrows; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; + +import org.junit.After; +import org.junit.Before; +import org.opensearch.client.ResponseException; + +import lombok.SneakyThrows; + +public class VectorDBToolIT extends BaseAgentToolsIT { + + public static String TEST_INDEX_NAME = "test_index"; + public static String TEST_NESTED_INDEX_NAME = "test_index_nested"; + + private String modelId; + private String registerAgentRequestBody; + + @SneakyThrows + private void prepareModel() { + String requestBody = Files + .readString( + Path + .of( + this + .getClass() + .getClassLoader() + .getResource("org/opensearch/agent/tools/register_text_embedding_model_request_body.json") + .toURI() + ) + ); + modelId = registerModelThenDeploy(requestBody); + } + + @SneakyThrows + private void prepareIndex() { + + String pipelineConfig = "{\n" + + " 
\"description\": \"text embedding pipeline\",\n" + + " \"processors\": [\n" + + " {\n" + + " \"text_embedding\": {\n" + + " \"model_id\": \"" + + modelId + + "\",\n" + + " \"field_map\": {\n" + + " \"text\": \"embedding\"\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + createIngestPipelineWithConfiguration("test-embedding-model", pipelineConfig); + + String indexMapping = "{\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"text\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"embedding\": {\n" + + " \"type\": \"knn_vector\",\n" + + " \"dimension\": 768,\n" + + " \"method\": {\n" + + " \"name\": \"hnsw\",\n" + + " \"space_type\": \"l2\",\n" + + " \"engine\": \"lucene\",\n" + + " \"parameters\": {\n" + + " \"ef_construction\": 128,\n" + + " \"m\": 24\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " },\n" + + " \"settings\": {\n" + + " \"index\": {\n" + + " \"knn.space_type\": \"cosinesimil\",\n" + + " \"default_pipeline\": \"test-embedding-model\",\n" + + " \"knn\": \"true\"\n" + + " }\n" + + " }\n" + + "}"; + + createIndexWithConfiguration(TEST_INDEX_NAME, indexMapping); + + addDocToIndex(TEST_INDEX_NAME, "0", List.of("text"), List.of("hello world")); + + addDocToIndex(TEST_INDEX_NAME, "1", List.of("text"), List.of("a b")); + } + + @SneakyThrows + private void prepareNestedIndex() { + String pipelineConfig = "{\n" + + " \"description\": \"text embedding pipeline\",\n" + + " \"processors\": [\n" + + " {\n" + + " \"text_embedding\": {\n" + + " \"model_id\": \"" + + modelId + + "\",\n" + + " \"field_map\": {\n" + + " \"text\": \"embedding\"\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + createIngestPipelineWithConfiguration("test-embedding-model", pipelineConfig); + + String indexMapping = "{\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"text\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"embedding\": {\n" + + " \"type\":\"nested\",\n" + + " \"properties\":{\n" + + " \"knn\":{\n" + + " \"type\": \"knn_vector\",\n" + + " \"dimension\": 768,\n" + + " \"method\": {\n" + + " \"name\": \"hnsw\",\n" + + " \"space_type\": \"l2\",\n" + + " \"engine\": \"lucene\",\n" + + " \"parameters\": {\n" + + " \"ef_construction\": 128,\n" + + " \"m\": 24\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " \n" + + " }\n" + + " }\n" + + " },\n" + + " \"settings\": {\n" + + " \"index\": {\n" + + " \"knn.space_type\": \"cosinesimil\",\n" + + " \"default_pipeline\": \"test-embedding-model\",\n" + + " \"knn\": \"true\"\n" + + " }\n" + + " }\n" + + "}"; + + createIndexWithConfiguration(TEST_NESTED_INDEX_NAME, indexMapping); + + addDocToIndex(TEST_NESTED_INDEX_NAME, "0", List.of("text"), List.of(List.of("hello world"))); + + addDocToIndex(TEST_NESTED_INDEX_NAME, "1", List.of("text"), List.of(List.of("a b"))); + } + + @Before + @SneakyThrows + public void setUp() { + super.setUp(); + prepareModel(); + prepareIndex(); + prepareNestedIndex(); + registerAgentRequestBody = Files + .readString( + Path + .of( + this + .getClass() + .getClassLoader() + .getResource("org/opensearch/agent/tools/register_flow_agent_of_vectordb_tool_request_body.json") + .toURI() + ) + ); + registerAgentRequestBody = registerAgentRequestBody.replace("", modelId); + + } + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + deleteExternalIndices(); + deleteModel(modelId); + } + + public void testVectorDBToolInFlowAgent() { + String agentId = createAgent(registerAgentRequestBody); + + // match similar text, doc1 match with higher score + String result = 
executeAgent(agentId, "{\"parameters\": {\"question\": \"c\"}}"); + + // To allow digits variation from model output, using string contains to match + assertTrue(result.contains("{\"_index\":\"test_index\",\"_source\":{\"text\":\"hello world\"},\"_id\":\"0\",\"_score\":0.70467")); + assertTrue(result.contains("{\"_index\":\"test_index\",\"_source\":{\"text\":\"a b\"},\"_id\":\"1\",\"_score\":0.26499")); + + // match exact same text case, doc0 match with higher score + String result1 = executeAgent(agentId, "{\"parameters\": {\"question\": \"hello\"}}"); + + // To allow digits variation from model output, using string contains to match + assertTrue( + result1.contains("{\"_index\":\"test_index\",\"_source\":{\"text\":\"hello world\"},\"_id\":\"0\",\"_score\":0.5671488") + ); + assertTrue(result1.contains("{\"_index\":\"test_index\",\"_source\":{\"text\":\"a b\"},\"_id\":\"1\",\"_score\":0.2423683")); + + // if blank input, call onFailure and get exception + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"\"}}")); + + org.hamcrest.MatcherAssert + .assertThat( + exception.getMessage(), + allOf(containsString("[input] is null or empty, can not process it."), containsString("IllegalArgumentException")) + ); + } + + public void testVectorDBToolInFlowAgent_withNestedIndex() { + String registerAgentRequestBodyNested = registerAgentRequestBody; + registerAgentRequestBodyNested = registerAgentRequestBodyNested.replace("\"nested_path\": \"\"", "\"nested_path\": \"embedding\""); + registerAgentRequestBodyNested = registerAgentRequestBodyNested + .replace("\"embedding_field\": \"embedding\"", "\"embedding_field\": \"embedding.knn\""); + registerAgentRequestBodyNested = registerAgentRequestBodyNested + .replace("\"index\": \"test_index\"", "\"index\": \"test_index_nested\""); + String agentId = createAgent(registerAgentRequestBodyNested); + String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}"); + // To allow digits variation from model output, using string contains to match + assertTrue( + result.contains("{\"_index\":\"test_index_nested\",\"_source\":{\"text\":[\"hello world\"]},\"_id\":\"0\",\"_score\":0.7") + ); + assertTrue(result.contains("{\"_index\":\"test_index_nested\",\"_source\":{\"text\":[\"a b\"]},\"_id\":\"1\",\"_score\":0.2")); + } + + public void testVectorDBToolInFlowAgent_withIllegalSourceField_thenGetEmptySource() { + String agentId = createAgent(registerAgentRequestBody.replace("text", "text2")); + String result = executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}"); + + // To allow digits variation from model output, using string contains to match + assertTrue(result.contains("{\"_index\":\"test_index\",\"_source\":{},\"_id\":\"0\",\"_score\":0.70493")); + assertTrue(result.contains("{\"_index\":\"test_index\",\"_source\":{},\"_id\":\"1\",\"_score\":0.26505")); + + } + + public void testVectorDBToolInFlowAgent_withIllegalEmbeddingField_thenThrowException() { + String agentId = createAgent(registerAgentRequestBody.replace("\"embedding\"", "\"embedding2\"")); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}")); + + org.hamcrest.MatcherAssert + .assertThat( + exception.getMessage(), + allOf(containsString("Field 'embedding2' is not knn_vector type."), containsString("IllegalArgumentException")) + ); + } + + public void testVectorDBToolInFlowAgent_withIllegalIndexField_thenThrowException() { + 
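+        // pointing the tool at a non-existent index should surface the wrapped IndexNotFoundException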
String agentId = createAgent(registerAgentRequestBody.replace("test_index", "test_index2")); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}")); + + org.hamcrest.MatcherAssert + .assertThat( + exception.getMessage(), + allOf(containsString("no such index [test_index2]"), containsString("IndexNotFoundException")) + ); + } + + public void testVectorDBToolInFlowAgent_withIllegalModelIdField_thenThrowException() { + String agentId = createAgent(registerAgentRequestBody.replace(modelId, "test_model_id")); + Exception exception = assertThrows(ResponseException.class, () -> executeAgent(agentId, "{\"parameters\": {\"question\": \"a\"}}")); + + org.hamcrest.MatcherAssert + .assertThat(exception.getMessage(), allOf(containsString("Failed to find model"), containsString("OpenSearchStatusException"))); + } +} diff --git a/src/test/resources/org/opensearch/agent/tools/alerting/alert_index_mappings.json b/src/test/resources/org/opensearch/agent/tools/alerting/alert_index_mappings.json new file mode 100644 index 00000000..c15410f5 --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/alerting/alert_index_mappings.json @@ -0,0 +1,181 @@ +{ + "mappings": { + "dynamic": "strict", + "_meta": { + "schema_version": 5 + }, + "properties": { + "schema_version": { + "type": "integer" + }, + "monitor_id": { + "type": "keyword" + }, + "monitor_version": { + "type": "long" + }, + "id": { + "type": "keyword" + }, + "version": { + "type": "long" + }, + "severity": { + "type": "keyword" + }, + "monitor_name": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "monitor_user": { + "properties": { + "name": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "backend_roles": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword" + } + } + }, + "roles": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword" + } + } + }, + "custom_attribute_names": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword" + } + } + } + } + }, + "execution_id": { + "type": "keyword" + }, + "workflow_id": { + "type": "keyword" + }, + "workflow_name": { + "type": "keyword" + }, + "trigger_id": { + "type": "keyword" + }, + "trigger_name": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "finding_ids": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword" + } + } + }, + "associated_alert_ids": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword" + } + } + }, + "related_doc_ids": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword" + } + } + }, + "state": { + "type": "keyword" + }, + "start_time": { + "type": "date" + }, + "last_notification_time": { + "type": "date" + }, + "acknowledged_time": { + "type": "date" + }, + "end_time": { + "type": "date" + }, + "error_message": { + "type": "text" + }, + "alert_history": { + "type": "nested", + "properties": { + "timestamp": { + "type": "date" + }, + "message": { + "type": "text" + } + } + }, + "action_execution_results": { + "type": "nested", + "properties": { + "action_id": { + "type": "keyword" + }, + "last_execution_time": { + "type": "date" + }, + "throttled_count": { + "type": "integer" + } + } + }, + "agg_alert_content": { + "dynamic": true, + "properties": { + "parent_bucket_path": { + "type": "text" + }, + "bucket_key": { + "type": "text" + } + } + 
}, + "clusters": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword" + } + } + } + } + } +} \ No newline at end of file diff --git a/src/test/resources/org/opensearch/agent/tools/alerting/alerting_config_index_mappings.json b/src/test/resources/org/opensearch/agent/tools/alerting/alerting_config_index_mappings.json new file mode 100644 index 00000000..759cd448 --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/alerting/alerting_config_index_mappings.json @@ -0,0 +1,1269 @@ +{ + "mappings": { + "_meta": { + "schema_version": 8 + }, + "properties": { + "audit_delegate_monitor_alerts": { + "type": "boolean" + }, + "data_sources": { + "properties": { + "alerts_history_index": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "alerts_history_index_pattern": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "alerts_index": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "findings_enabled": { + "type": "boolean" + }, + "findings_index": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "findings_index_pattern": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "query_index": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "query_index_mappings_by_type": { + "type": "object" + } + } + }, + "destination": { + "dynamic": "false", + "properties": { + "chime": { + "properties": { + "url": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + } + } + }, + "custom_webhook": { + "properties": { + "header_params": { + "type": "object", + "enabled": false + }, + "host": { + "type": "text" + }, + "password": { + "type": "text" + }, + "path": { + "type": "keyword" + }, + "port": { + "type": "integer" + }, + "query_params": { + "type": "object", + "enabled": false + }, + "scheme": { + "type": "keyword" + }, + "url": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "username": { + "type": "text" + } + } + }, + "email": { + "properties": { + "email_account_id": { + "type": "keyword" + }, + "recipients": { + "type": "nested", + "properties": { + "email": { + "type": "text" + }, + "email_group_id": { + "type": "keyword" + }, + "type": { + "type": "keyword" + } + } + } + } + }, + "last_update_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "name": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "schema_version": { + "type": "integer" + }, + "slack": { + "properties": { + "url": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + } + } + }, + "type": { + "type": "keyword" + }, + "user": { + "properties": { + "backend_roles": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword" + } + } + }, + "custom_attribute_names": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword" + } + } + }, + "name": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "roles": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword" + } + } + } + } + } + } + }, + "email_account": { + "properties": { + "from": { + "type": "text" + 
}, + "host": { + "type": "text" + }, + "method": { + "type": "text" + }, + "name": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "port": { + "type": "integer" + } + } + }, + "email_group": { + "properties": { + "emails": { + "type": "nested", + "properties": { + "email": { + "type": "text" + } + } + }, + "name": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + } + } + }, + "enabled": { + "type": "boolean" + }, + "inputs": { + "properties": { + "composite_input": { + "properties": { + "sequence": { + "properties": { + "delegates": { + "properties": { + "monitor_id": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "order": { + "type": "long" + } + } + } + } + } + } + }, + "doc_level_input": { + "properties": { + "description": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "indices": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "queries": { + "properties": { + "id": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "name": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "query": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + } + } + } + } + }, + "search": { + "properties": { + "indices": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "query": { + "properties": { + "aggregations": { + "properties": { + "metric": { + "properties": { + "avg": { + "properties": { + "field": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + } + } + } + } + } + } + }, + "query": { + "properties": { + "bool": { + "properties": { + "adjust_pure_negative": { + "type": "boolean" + }, + "boost": { + "type": "long" + }, + "filter": { + "properties": { + "range": { + "properties": { + "dayOfWeek": { + "properties": { + "boost": { + "type": "long" + }, + "from": { + "type": "long" + }, + "include_lower": { + "type": "boolean" + }, + "include_upper": { + "type": "boolean" + }, + "to": { + "type": "long" + } + } + }, + "timestamp": { + "properties": { + "boost": { + "type": "long" + }, + "format": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "from": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "include_lower": { + "type": "boolean" + }, + "include_upper": { + "type": "boolean" + }, + "to": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + } + } + } + } + }, + "term": { + "properties": { + "dayOfWeek": { + "properties": { + "boost": { + "type": "long" + }, + "value": { + "type": "long" + } + } + } + } + } + } + } + } + } + } + }, + "size": { + "type": "long" + } + } + } + } + }, + "uri": { + "properties": { + "api_type": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "path": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "path_params": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "url": { + "type": "text", 
+ "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + } + } + } + } + }, + "last_update_time": { + "type": "long" + }, + "metadata": { + "properties": { + "last_action_execution_times": { + "type": "nested", + "properties": { + "action_id": { + "type": "keyword" + }, + "execution_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + } + } + }, + "last_run_context": { + "type": "object", + "enabled": false + }, + "monitor_id": { + "type": "keyword" + }, + "source_to_query_index_mapping": { + "type": "object", + "enabled": false + } + } + }, + "monitor": { + "dynamic": "false", + "properties": { + "data_sources": { + "properties": { + "alerts_index": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "findings_index": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "query_index": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "query_index_mapping": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + } + } + }, + "enabled": { + "type": "boolean" + }, + "enabled_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "group_by_fields": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "inputs": { + "type": "nested", + "properties": { + "search": { + "properties": { + "indices": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "query": { + "type": "object", + "enabled": false + } + } + } + } + }, + "last_update_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "monitor_type": { + "type": "keyword" + }, + "name": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "owner": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "schedule": { + "properties": { + "cron": { + "properties": { + "expression": { + "type": "text" + }, + "timezone": { + "type": "keyword" + } + } + }, + "period": { + "properties": { + "interval": { + "type": "integer" + }, + "unit": { + "type": "keyword" + } + } + } + } + }, + "schema_version": { + "type": "integer" + }, + "triggers": { + "type": "nested", + "properties": { + "actions": { + "type": "nested", + "properties": { + "destination_id": { + "type": "keyword" + }, + "message_template": { + "type": "object", + "enabled": false + }, + "name": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "subject_template": { + "type": "object", + "enabled": false + }, + "throttle": { + "properties": { + "unit": { + "type": "keyword" + }, + "value": { + "type": "integer" + } + } + }, + "throttle_enabled": { + "type": "boolean" + } + } + }, + "condition": { + "type": "object", + "enabled": false + }, + "min_time_between_executions": { + "type": "integer" + }, + "name": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "query_level_trigger": { + "properties": { + "actions": { + "type": "nested", + "properties": { + "destination_id": { + "type": "keyword" + }, + "message_template": { + "type": "object", + "enabled": false + }, + "name": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + 
} + }, + "subject_template": { + "type": "object", + "enabled": false + }, + "throttle": { + "properties": { + "unit": { + "type": "keyword" + }, + "value": { + "type": "integer" + } + } + }, + "throttle_enabled": { + "type": "boolean" + } + } + }, + "condition": { + "type": "object", + "enabled": false + }, + "min_time_between_executions": { + "type": "integer" + }, + "name": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + } + } + } + } + }, + "type": { + "type": "keyword" + }, + "ui_metadata": { + "type": "object", + "enabled": false + }, + "user": { + "properties": { + "backend_roles": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword" + } + } + }, + "custom_attribute_names": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword" + } + } + }, + "name": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "roles": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword" + } + } + } + } + } + } + }, + "monitor_type": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "name": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "owner": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "schedule": { + "properties": { + "period": { + "properties": { + "interval": { + "type": "long" + }, + "unit": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + } + } + } + } + }, + "schema_version": { + "type": "long" + }, + "triggers": { + "properties": { + "chained_alert_trigger": { + "properties": { + "condition": { + "properties": { + "script": { + "properties": { + "lang": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "source": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + } + } + } + } + }, + "id": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "name": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "severity": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + } + } + }, + "document_level_trigger": { + "properties": { + "condition": { + "properties": { + "script": { + "properties": { + "lang": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "source": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + } + } + } + } + }, + "id": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "name": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "severity": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + } + } + }, + "query_level_trigger": { + "properties": { + "condition": { + "properties": { + "script": { + "properties": { + "lang": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "source": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + } + } + } + } + }, + "id": { + "type": 
"text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "name": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "severity": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + } + } + } + } + }, + "type": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "workflow": { + "dynamic": "false", + "properties": { + "audit_delegate_monitor_alerts": { + "type": "boolean" + }, + "enabled": { + "type": "boolean" + }, + "enabled_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "group_by_fields": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "inputs": { + "type": "nested", + "properties": { + "composite_input": { + "type": "nested", + "properties": { + "sequence": { + "properties": { + "delegates": { + "type": "nested", + "properties": { + "chained_monitor_findings": { + "properties": { + "monitor_id": { + "type": "keyword" + } + } + }, + "monitor_id": { + "type": "keyword" + }, + "order": { + "type": "integer" + } + } + } + } + } + } + } + } + }, + "last_update_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "name": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "owner": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "schedule": { + "properties": { + "cron": { + "properties": { + "expression": { + "type": "text" + }, + "timezone": { + "type": "keyword" + } + } + }, + "period": { + "properties": { + "interval": { + "type": "integer" + }, + "unit": { + "type": "keyword" + } + } + } + } + }, + "schema_version": { + "type": "integer" + }, + "type": { + "type": "keyword" + }, + "user": { + "properties": { + "backend_roles": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword" + } + } + }, + "custom_attribute_names": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword" + } + } + }, + "name": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "roles": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword" + } + } + } + } + }, + "workflow_type": { + "type": "keyword" + } + } + }, + "workflow_metadata": { + "properties": { + "latest_execution_id": { + "type": "keyword" + }, + "latest_run_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "monitor_ids": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 1000 + } + } + }, + "workflow_id": { + "type": "keyword" + } + } + }, + "workflow_type": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + } + } + } +} \ No newline at end of file diff --git a/src/test/resources/org/opensearch/agent/tools/alerting/register_flow_agent_of_search_alerts_tool_request_body.json b/src/test/resources/org/opensearch/agent/tools/alerting/register_flow_agent_of_search_alerts_tool_request_body.json new file mode 100644 index 00000000..71abfc78 --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/alerting/register_flow_agent_of_search_alerts_tool_request_body.json @@ -0,0 +1,9 @@ +{ + "name": "Test_Search_Alerts_Agent", + "type": "flow", + "tools": [ + { + "type": "SearchAlertsTool" + } + ] +} \ No newline at end of file 
diff --git a/src/test/resources/org/opensearch/agent/tools/alerting/register_flow_agent_of_search_monitors_tool_request_body.json b/src/test/resources/org/opensearch/agent/tools/alerting/register_flow_agent_of_search_monitors_tool_request_body.json new file mode 100644 index 00000000..200203fd --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/alerting/register_flow_agent_of_search_monitors_tool_request_body.json @@ -0,0 +1,9 @@ +{ + "name": "Test_Search_Monitors_Agent", + "type": "flow", + "tools": [ + { + "type": "SearchMonitorsTool" + } + ] +} \ No newline at end of file diff --git a/src/test/resources/org/opensearch/agent/tools/alerting/sample_alert.json b/src/test/resources/org/opensearch/agent/tools/alerting/sample_alert.json new file mode 100644 index 00000000..2a15c7a7 --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/alerting/sample_alert.json @@ -0,0 +1,24 @@ +{ + "monitor_id": "foo-monitor-id", + "workflow_id": "", + "workflow_name": "", + "associated_alert_ids": [], + "schema_version": 5, + "monitor_version": 1, + "monitor_name": "foo-monitor", + "execution_id": "foo-execution-id", + "trigger_id": "foo-trigger-id", + "trigger_name": "foo-trigger-name", + "finding_ids": [], + "related_doc_ids": [], + "state": "COMPLETED", + "error_message": null, + "alert_history": [], + "severity": "2", + "action_execution_results": [], + "start_time": 1234, + "last_notification_time": 1234, + "end_time": 1234, + "acknowledged_time": null, + "clusters": [] +} \ No newline at end of file diff --git a/src/test/resources/org/opensearch/agent/tools/alerting/sample_monitor.json b/src/test/resources/org/opensearch/agent/tools/alerting/sample_monitor.json new file mode 100644 index 00000000..d7e071a4 --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/alerting/sample_monitor.json @@ -0,0 +1,15 @@ +{ + "type": "monitor", + "schema_version": 0, + "name": "foo-monitor", + "monitor_type": "query_level_monitor", + "enabled": true, + "schedule": { + "period": { + "interval": 1, + "unit": "MINUTES" + } + }, + "inputs": [], + "triggers": [] +} \ No newline at end of file diff --git a/src/test/resources/org/opensearch/agent/tools/anomaly-detection/register_flow_agent_of_search_anomaly_detectors_tool_request_body.json b/src/test/resources/org/opensearch/agent/tools/anomaly-detection/register_flow_agent_of_search_anomaly_detectors_tool_request_body.json new file mode 100644 index 00000000..b65eb44e --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/anomaly-detection/register_flow_agent_of_search_anomaly_detectors_tool_request_body.json @@ -0,0 +1,9 @@ +{ + "name": "Test_Search_Detectors_Agent", + "type": "flow", + "tools": [ + { + "type": "SearchAnomalyDetectorsTool" + } + ] +} \ No newline at end of file diff --git a/src/test/resources/org/opensearch/agent/tools/anomaly-detection/register_flow_agent_of_search_anomaly_results_tool_request_body.json b/src/test/resources/org/opensearch/agent/tools/anomaly-detection/register_flow_agent_of_search_anomaly_results_tool_request_body.json new file mode 100644 index 00000000..03e2f354 --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/anomaly-detection/register_flow_agent_of_search_anomaly_results_tool_request_body.json @@ -0,0 +1,9 @@ +{ + "name": "Test_Search_Anomaly_Results_Agent", + "type": "flow", + "tools": [ + { + "type": "SearchAnomalyResultsTool" + } + ] +} \ No newline at end of file diff --git a/src/test/resources/org/opensearch/agent/tools/anomaly-detection/results_index_mappings.json 
b/src/test/resources/org/opensearch/agent/tools/anomaly-detection/results_index_mappings.json new file mode 100644 index 00000000..ee4e5e26 --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/anomaly-detection/results_index_mappings.json @@ -0,0 +1,161 @@ +{ + "mappings": { + "dynamic": "false", + "_meta": { + "schema_version": 5 + }, + "properties": { + "anomaly_grade": { + "type": "double" + }, + "anomaly_score": { + "type": "double" + }, + "approx_anomaly_start_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "confidence": { + "type": "double" + }, + "data_end_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "data_start_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "detector_id": { + "type": "keyword" + }, + "entity": { + "type": "nested", + "properties": { + "name": { + "type": "keyword" + }, + "value": { + "type": "keyword" + } + } + }, + "error": { + "type": "text" + }, + "execution_end_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "execution_start_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "expected_values": { + "type": "nested", + "properties": { + "likelihood": { + "type": "double" + }, + "value_list": { + "type": "nested", + "properties": { + "data": { + "type": "double" + }, + "feature_id": { + "type": "keyword" + } + } + } + } + }, + "feature_data": { + "type": "nested", + "properties": { + "data": { + "type": "double" + }, + "feature_id": { + "type": "keyword" + } + } + }, + "is_anomaly": { + "type": "boolean" + }, + "model_id": { + "type": "keyword" + }, + "past_values": { + "type": "nested", + "properties": { + "data": { + "type": "double" + }, + "feature_id": { + "type": "keyword" + } + } + }, + "relevant_attribution": { + "type": "nested", + "properties": { + "data": { + "type": "double" + }, + "feature_id": { + "type": "keyword" + } + } + }, + "schema_version": { + "type": "integer" + }, + "task_id": { + "type": "keyword" + }, + "threshold": { + "type": "double" + }, + "user": { + "type": "nested", + "properties": { + "backend_roles": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword" + } + } + }, + "custom_attribute_names": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword" + } + } + }, + "name": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "roles": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword" + } + } + } + } + } + } + } +} \ No newline at end of file diff --git a/src/test/resources/org/opensearch/agent/tools/anomaly-detection/sample_detector.json b/src/test/resources/org/opensearch/agent/tools/anomaly-detection/sample_detector.json new file mode 100644 index 00000000..b23a3e99 --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/anomaly-detection/sample_detector.json @@ -0,0 +1,48 @@ +{ + "name": "test-detector", + "description": "Test detector", + "time_field": "timestamp", + "indices": [ + "test-index" + ], + "feature_attributes": [ + { + "feature_name": "test", + "feature_enabled": true, + "aggregation_query": { + "test": { + "sum": { + "field": "value" + } + } + } + } + ], + "filter_query": { + "bool": { + "filter": [ + { + "range": { + "value": { + "gt": 1 + } + } + } + ], + "adjust_pure_negative": true, + "boost": 1 + } + }, + "detection_interval": { + "period": { + "interval": 1, + "unit": "Minutes" + } + }, + "window_delay": { + "period": { + "interval": 1, 
+ "unit": "Minutes" + } + } +} \ No newline at end of file diff --git a/src/test/resources/org/opensearch/agent/tools/anomaly-detection/sample_index_mappings.json b/src/test/resources/org/opensearch/agent/tools/anomaly-detection/sample_index_mappings.json new file mode 100644 index 00000000..0697e7bf --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/anomaly-detection/sample_index_mappings.json @@ -0,0 +1,12 @@ +{ + "mappings": { + "properties": { + "value": { + "type": "integer" + }, + "timestamp": { + "type": "date" + } + } + } +} \ No newline at end of file diff --git a/src/test/resources/org/opensearch/agent/tools/anomaly-detection/sample_result.json b/src/test/resources/org/opensearch/agent/tools/anomaly-detection/sample_result.json new file mode 100644 index 00000000..d81a4c32 --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/anomaly-detection/sample_result.json @@ -0,0 +1,19 @@ +{ + "detector_id": "foo-id", + "schema_version": 5, + "data_start_time": 1234, + "data_end_time": 1234, + "feature_data": [ + { + "feature_id": "foo-feature-id", + "feature_name": "foo-feature-name", + "data": 1 + } + ], + "execution_start_time": 1234, + "execution_end_time": 1234, + "anomaly_score": 0.5, + "anomaly_grade": 0.5, + "confidence": 0.5, + "threshold": 0.8 +} \ No newline at end of file diff --git a/src/test/resources/org/opensearch/agent/tools/neural_sparse_tool_search_response.json b/src/test/resources/org/opensearch/agent/tools/neural_sparse_tool_search_response.json new file mode 100644 index 00000000..196e8a04 --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/neural_sparse_tool_search_response.json @@ -0,0 +1,71 @@ +{ + "took" : 688, + "timed_out" : false, + "_shards" : { + "total" : 1, + "successful" : 1, + "skipped" : 0, + "failed" : 0 + }, + "hits" : { + "total" : { + "value" : 2, + "relation" : "eq" + }, + "max_score" : 30.0029, + "hits" : [ + { + "_index" : "my-nlp-index", + "_id" : "1", + "_score" : 30.0029, + "_source" : { + "passage_text" : "Hello world", + "passage_embedding" : { + "!" : 0.8708904, + "door" : 0.8587369, + "hi" : 2.3929274, + "worlds" : 2.7839446, + "yes" : 0.75845814, + "##world" : 2.5432441, + "born" : 0.2682308, + "nothing" : 0.8625516, + "goodbye" : 0.17146169, + "greeting" : 0.96817183, + "birth" : 1.2788506, + "come" : 0.1623208, + "global" : 0.4371151, + "it" : 0.42951578, + "life" : 1.5750692, + "thanks" : 0.26481047, + "world" : 4.7300377, + "tiny" : 0.5462298, + "earth" : 2.6555297, + "universe" : 2.0308156, + "worldwide" : 1.3903781, + "hello" : 6.696973, + "so" : 0.20279501, + "?" 
: 0.67785245 + }, + "id" : "s1" + } + }, + { + "_index" : "my-nlp-index", + "_id" : "2", + "_score" : 16.480486, + "_source" : { + "passage_text" : "Hi planet", + "passage_embedding" : { + "hi" : 4.338913, + "planets" : 2.7755864, + "planet" : 5.0969057, + "mars" : 1.7405145, + "earth" : 2.6087382, + "hello" : 3.3210192 + }, + "id" : "s2" + } + } + ] + } +} \ No newline at end of file diff --git a/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_create_anomaly_detector_tool_request_body.json b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_create_anomaly_detector_tool_request_body.json new file mode 100644 index 00000000..3ad9477e --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_create_anomaly_detector_tool_request_body.json @@ -0,0 +1,12 @@ +{ + "name": "Test_create_anomaly_detector_flow_agent", + "type": "flow", + "tools": [ + { + "type": "CreateAnomalyDetectorTool", + "parameters": { + "model_id": "" + } + } + ] +} diff --git a/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_neural_sparse_search_tool_request_body.json b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_neural_sparse_search_tool_request_body.json new file mode 100644 index 00000000..579f0778 --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_neural_sparse_search_tool_request_body.json @@ -0,0 +1,18 @@ +{ + "name": "Test_Neural_Sparse_Agent_For_RAG", + "type": "flow", + "tools": [ + { + "type": "NeuralSparseSearchTool", + "parameters": { + "description":"use this tool to search data from the test index", + "model_id": "", + "index": "test_index", + "embedding_field": "embedding", + "source_field": ["text"], + "input": "${parameters.question}", + "nested_path": "" + } + } + ] +} \ No newline at end of file diff --git a/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_ppl_tool_request_body.json b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_ppl_tool_request_body.json new file mode 100644 index 00000000..0e1a167e --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_ppl_tool_request_body.json @@ -0,0 +1,13 @@ +{ + "name": "Test_PPL_Agent_For_RAG", + "type": "flow", + "tools": [ + { + "type": "PPLTool", + "parameters": { + "model_id": "", + "prompt": "Below is an instruction that describes a task, paired with the index and corresponding fields that provide further context. Write a response that appropriately completes the request.\n\n### Instruction:\nI have an opensearch index with the following fields. 
Now I have a question: ${indexInfo.question} Can you help me generate a PPL for that?\n\n### Index:\n${indexInfo.mappingInfo}\n\n### Fields:\n${indexInfo.indexName}\n\n### Response:\n" + } + } + ] +} \ No newline at end of file diff --git a/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_ragtool_with_neural_query_type_request_body.json b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_ragtool_with_neural_query_type_request_body.json new file mode 100644 index 00000000..7f8a4af2 --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_ragtool_with_neural_query_type_request_body.json @@ -0,0 +1,23 @@ +{ + "name": "Test_Agent_For_RagTool", + "type": "flow", + "description": "this is a test flow agent", + "tools": [ + { + "type": "RAGTool", + "description": "A description of the tool", + "parameters": { + "embedding_model_id": "", + "index": "", + "embedding_field": "embedding", + "query_type": "neural", + "enable_content_generation":"false", + "source_field": [ + "text" + ], + "input": "${parameters.question}", + "prompt": "\n\nHuman:You are a professional data analyst. You will always answer the question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say you don't know. \n\n Context:\n${parameters.output_field}\n\nHuman:${parameters.question}\n\nAssistant:" + } + } + ] +} \ No newline at end of file diff --git a/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_vectordb_tool_request_body.json b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_vectordb_tool_request_body.json new file mode 100644 index 00000000..b1488388 --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/register_flow_agent_of_vectordb_tool_request_body.json @@ -0,0 +1,18 @@ +{ + "name": "Test_VectorDB_Agent", + "type": "flow", + "tools": [ + { + "type": "VectorDBTool", + "parameters": { + "description":"use this tool to search data from the test index", + "model_id": "", + "index": "test_index", + "embedding_field": "embedding", + "source_field": ["text"], + "input": "${parameters.question}", + "nested_path": "" + } + } + ] +} \ No newline at end of file diff --git a/src/test/resources/org/opensearch/agent/tools/register_sparse_encoding_model_request_body.json b/src/test/resources/org/opensearch/agent/tools/register_sparse_encoding_model_request_body.json new file mode 100644 index 00000000..8eb7901c --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/register_sparse_encoding_model_request_body.json @@ -0,0 +1,5 @@ +{ + "name":"amazon/neural-sparse/opensearch-neural-sparse-tokenizer-v1", + "version":"1.0.1", + "model_format": "TORCH_SCRIPT" +} \ No newline at end of file diff --git a/src/test/resources/org/opensearch/agent/tools/register_text_embedding_model_request_body.json b/src/test/resources/org/opensearch/agent/tools/register_text_embedding_model_request_body.json new file mode 100644 index 00000000..0173665a --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/register_text_embedding_model_request_body.json @@ -0,0 +1,14 @@ +{ + "name": "traced_small_model", + "version": "1.0.0", + "model_format": "TORCH_SCRIPT", + "model_task_type": "text_embedding", + "model_content_hash_value": "e13b74006290a9d0f58c1376f9629d4ebc05a0f9385f40db837452b167ae9021", + "model_config": { + "model_type": "bert", + "embedding_dimension": 768, + "framework_type": 
"sentence_transformers", + "all_config": "{\"architectures\":[\"BertModel\"],\"max_position_embeddings\":512,\"model_type\":\"bert\",\"num_attention_heads\":12,\"num_hidden_layers\":6}" + }, + "url": "https://github.com/opensearch-project/ml-commons/blob/2.x/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_embedding/traced_small_model.zip?raw=true" +} \ No newline at end of file diff --git a/src/test/resources/org/opensearch/agent/tools/retrieval_tool_empty_search_response.json b/src/test/resources/org/opensearch/agent/tools/retrieval_tool_empty_search_response.json new file mode 100644 index 00000000..7ca6bfa7 --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/retrieval_tool_empty_search_response.json @@ -0,0 +1,18 @@ +{ + "took": 4, + "timed_out": false, + "_shards": { + "total": 1, + "successful": 1, + "skipped": 0, + "failed": 0 + }, + "hits": { + "total": { + "value": 0, + "relation": "eq" + }, + "max_score": null, + "hits": [] + } +} \ No newline at end of file diff --git a/src/test/resources/org/opensearch/agent/tools/retrieval_tool_search_response.json b/src/test/resources/org/opensearch/agent/tools/retrieval_tool_search_response.json new file mode 100644 index 00000000..d89ad3b0 --- /dev/null +++ b/src/test/resources/org/opensearch/agent/tools/retrieval_tool_search_response.json @@ -0,0 +1,35 @@ +{ + "took": 201, + "timed_out": false, + "_shards": { + "total": 1, + "successful": 1, + "skipped": 0, + "failed": 0 + }, + "hits": { + "total": { + "value": 2, + "relation": "eq" + }, + "max_score": 89.2917, + "hits": [ + { + "_index": "hybrid-index", + "_id": "1", + "_score": 89.2917, + "_source": { + "passage_text": "Company test_mock have a history of 100 years." + } + }, + { + "_index": "hybrid-index", + "_id": "2", + "_score": 0.10702579, + "_source": { + "passage_text": "the price of the api is 2$ per invocation" + } + } + ] + } +} \ No newline at end of file