diff --git a/alerting/src/main/kotlin/org/opensearch/alerting/resthandler/RestGetWorkflowAlertsAction.kt b/alerting/src/main/kotlin/org/opensearch/alerting/resthandler/RestGetWorkflowAlertsAction.kt index 5fb7d8ffc..474c32d4a 100644 --- a/alerting/src/main/kotlin/org/opensearch/alerting/resthandler/RestGetWorkflowAlertsAction.kt +++ b/alerting/src/main/kotlin/org/opensearch/alerting/resthandler/RestGetWorkflowAlertsAction.kt @@ -55,11 +55,16 @@ class RestGetWorkflowAlertsAction : BaseRestHandler() { val severityLevel = request.param("severityLevel", "ALL") val alertState = request.param("alertState", "ALL") val workflowId: String? = request.param("workflowIds") + val alertId: String? = request.param("alertIds") val getAssociatedAlerts: Boolean = request.param("getAssociatedAlerts", "false").toBoolean() val workflowIds = mutableListOf() if (workflowId.isNullOrEmpty() == false) { workflowIds.add(workflowId) } + val alertIds = mutableListOf() + if (alertId.isNullOrEmpty() == false) { + alertIds.add(alertId) + } val table = Table( sortOrder, sortString, @@ -77,7 +82,8 @@ class RestGetWorkflowAlertsAction : BaseRestHandler() { associatedAlertsIndex = null, workflowIds = workflowIds, monitorIds = emptyList(), - getAssociatedAlerts = getAssociatedAlerts + getAssociatedAlerts = getAssociatedAlerts, + alertIds = alertIds ) return RestChannelConsumer { channel -> client.execute(AlertingActions.GET_WORKFLOW_ALERTS_ACTION_TYPE, getWorkflowAlertsRequest, RestToXContentListener(channel)) diff --git a/alerting/src/main/kotlin/org/opensearch/alerting/transport/TransportGetWorkflowAlertsAction.kt b/alerting/src/main/kotlin/org/opensearch/alerting/transport/TransportGetWorkflowAlertsAction.kt index c2df4baf1..b1c9209ef 100644 --- a/alerting/src/main/kotlin/org/opensearch/alerting/transport/TransportGetWorkflowAlertsAction.kt +++ b/alerting/src/main/kotlin/org/opensearch/alerting/transport/TransportGetWorkflowAlertsAction.kt @@ -125,13 +125,18 @@ class TransportGetWorkflowAlertsAction @Inject constructor( .field("trigger_name") ) } + // if alert id is mentioned we cannot set "from" field as it may not return id. we would be using it to paginate associated alerts + val from = if (getWorkflowAlertsRequest.alertIds.isNullOrEmpty()) + tableProp.startIndex + else 0 + val searchSourceBuilder = SearchSourceBuilder() .version(true) .seqNoAndPrimaryTerm(true) .query(queryBuilder) .sort(sortBuilder) .size(tableProp.size) - .from(tableProp.startIndex) + .from(from) client.threadPool().threadContext.stashContext().use { scope.launch { @@ -205,22 +210,42 @@ class TransportGetWorkflowAlertsAction @Inject constructor( parseAlertsFromSearchResponse(response) ) if (alerts.isNotEmpty() && getWorkflowAlertsRequest.getAssociatedAlerts == true) - getAssociatedAlerts(associatedAlerts, alerts, resolveAssociatedAlertsIndexName(getWorkflowAlertsRequest)) + getAssociatedAlerts( + associatedAlerts, + alerts, + resolveAssociatedAlertsIndexName(getWorkflowAlertsRequest), + getWorkflowAlertsRequest + ) actionListener.onResponse(GetWorkflowAlertsResponse(alerts, associatedAlerts, totalAlertCount)) } catch (e: Exception) { actionListener.onFailure(AlertingException("Failed to get alerts", RestStatus.INTERNAL_SERVER_ERROR, e)) } } - private suspend fun getAssociatedAlerts(associatedAlerts: MutableList, alerts: MutableList, alertIndex: String) { + private suspend fun getAssociatedAlerts( + associatedAlerts: MutableList, + alerts: MutableList, + alertIndex: String, + getWorkflowAlertsRequest: GetWorkflowAlertsRequest, + ) { try { val associatedAlertIds = mutableSetOf() alerts.forEach { associatedAlertIds.addAll(it.associatedAlertIds) } if (associatedAlertIds.isEmpty()) return val queryBuilder = QueryBuilders.boolQuery() + val searchRequest = SearchRequest(alertIndex) + // if chained alert id param is non-null, paginate the associated alerts. + if (getWorkflowAlertsRequest.alertIds.isNullOrEmpty() == false) { + val tableProp = getWorkflowAlertsRequest.table + val sortBuilder = SortBuilders.fieldSort(tableProp.sortString) + .order(SortOrder.fromString(tableProp.sortOrder)) + if (!tableProp.missing.isNullOrBlank()) { + sortBuilder.missing(tableProp.missing) + } + searchRequest.source().sort(sortBuilder).size(tableProp.size).from(tableProp.startIndex) + } queryBuilder.must(QueryBuilders.termsQuery("_id", associatedAlertIds)) queryBuilder.must(QueryBuilders.termQuery(Alert.STATE_FIELD, Alert.State.AUDIT)) - val searchRequest = SearchRequest(alertIndex) searchRequest.source().query(queryBuilder) val response: SearchResponse = client.suspendUntil { search(searchRequest, it) } associatedAlerts.addAll(parseAlertsFromSearchResponse(response)) diff --git a/alerting/src/test/kotlin/org/opensearch/alerting/AlertingRestTestCase.kt b/alerting/src/test/kotlin/org/opensearch/alerting/AlertingRestTestCase.kt index bdc568d24..d58b15c75 100644 --- a/alerting/src/test/kotlin/org/opensearch/alerting/AlertingRestTestCase.kt +++ b/alerting/src/test/kotlin/org/opensearch/alerting/AlertingRestTestCase.kt @@ -814,9 +814,13 @@ abstract class AlertingRestTestCase : ODFERestTestCase() { protected fun getWorkflowAlerts( workflowId: String, + alertId: String? = "", getAssociatedAlerts: Boolean = true, ): Response { - return getWorkflowAlerts(client(), mutableMapOf(Pair("workflowIds", workflowId), Pair("getAssociatedAlerts", getAssociatedAlerts))) + return getWorkflowAlerts( + client(), + mutableMapOf(Pair("workflowIds", workflowId), Pair("getAssociatedAlerts", getAssociatedAlerts), Pair("alertIds", alertId!!)) + ) } protected fun getWorkflowAlerts( diff --git a/alerting/src/test/kotlin/org/opensearch/alerting/MonitorDataSourcesIT.kt b/alerting/src/test/kotlin/org/opensearch/alerting/MonitorDataSourcesIT.kt index c19e8b556..752ccd150 100644 --- a/alerting/src/test/kotlin/org/opensearch/alerting/MonitorDataSourcesIT.kt +++ b/alerting/src/test/kotlin/org/opensearch/alerting/MonitorDataSourcesIT.kt @@ -6,6 +6,7 @@ package org.opensearch.alerting import org.junit.Assert +import org.opensearch.action.DocWriteRequest import org.opensearch.action.admin.cluster.state.ClusterStateRequest import org.opensearch.action.admin.indices.alias.Alias import org.opensearch.action.admin.indices.close.CloseIndexRequest @@ -17,6 +18,8 @@ import org.opensearch.action.admin.indices.mapping.get.GetMappingsRequest import org.opensearch.action.admin.indices.mapping.put.PutMappingRequest import org.opensearch.action.admin.indices.open.OpenIndexRequest import org.opensearch.action.admin.indices.refresh.RefreshRequest +import org.opensearch.action.bulk.BulkRequest +import org.opensearch.action.bulk.BulkResponse import org.opensearch.action.fieldcaps.FieldCapabilitiesRequest import org.opensearch.action.index.IndexRequest import org.opensearch.action.search.SearchRequest @@ -3693,7 +3696,7 @@ class MonitorDataSourcesIT : AlertingSingleNodeTestCase() { alertsIndex: String? = AlertIndices.ALERT_INDEX, executionId: String? = null, alertSize: Int, - workflowId: String + workflowId: String, ): GetAlertsResponse { val alerts = searchAlerts(monitorId, alertsIndex!!, executionId = executionId) assertEquals("Alert saved for test monitor", alertSize, alerts.size) @@ -5639,4 +5642,83 @@ class MonitorDataSourcesIT : AlertingSingleNodeTestCase() { } Assert.assertTrue(alerts.stream().anyMatch { it.state == Alert.State.DELETED && chainedAlerts[0].id == it.id }) } + + fun `test get chained alerts with alertId paginating for associated alerts`() { + val docQuery1 = DocLevelQuery(query = "test_field_1:\"us-west-2\"", name = "3") + val docLevelInput1 = DocLevelMonitorInput("description", listOf(index), listOf(docQuery1)) + val trigger1 = randomDocumentLevelTrigger(condition = ALWAYS_RUN) + var monitor1 = randomDocumentLevelMonitor( + inputs = listOf(docLevelInput1), + triggers = listOf(trigger1) + ) + var monitor2 = randomDocumentLevelMonitor( + inputs = listOf(docLevelInput1), + triggers = listOf(trigger1) + ) + val monitorResponse = createMonitor(monitor1)!! + val monitorResponse2 = createMonitor(monitor2)!! + + val andTrigger = randomChainedAlertTrigger( + name = "1And2", + condition = Script("monitor[id=${monitorResponse.id}] && monitor[id=${monitorResponse2.id}]") + ) + var workflow = randomWorkflow( + monitorIds = listOf(monitorResponse.id, monitorResponse2.id), + triggers = listOf(andTrigger) + ) + val workflowResponse = upsertWorkflow(workflow)!! + val workflowById = searchWorkflow(workflowResponse.id) + val workflowId = workflowById!!.id + val testTime = DateTimeFormatter.ISO_OFFSET_DATE_TIME.format(ZonedDateTime.now().truncatedTo(MILLIS)) + val testDoc1 = """{ + "message" : "This is an error from IAD region", + "source.ip.v6.v2" : 16644, + "test_strict_date_time" : "$testTime", + "test_field_1" : "us-west-2" + }""" + var i = 1 + val indexRequests = mutableListOf() + while (i++ < 300) { + indexRequests += IndexRequest(index).source(testDoc1, XContentType.JSON).id("$i").opType(DocWriteRequest.OpType.INDEX) + } + val bulkResponse: BulkResponse = + client().bulk(BulkRequest().add(indexRequests).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)).get() + if (bulkResponse.hasFailures()) { + fail("Bulk request to index to test index has failed") + } + var executeWorkflowResponse = executeWorkflow(workflowById, workflowId, false)!! + var res = getWorkflowAlerts( + workflowId = workflowId + ) + Assert.assertTrue(executeWorkflowResponse.workflowRunResult.triggerResults[andTrigger.id]!!.triggered) + var chainedAlerts = res.alerts + Assert.assertTrue(chainedAlerts.size == 1) + Assert.assertEquals(res.associatedAlerts.size, 10) + var res100to200 = getWorkflowAlerts( + workflowId = workflowId, + alertIds = listOf(res.alerts[0].id), + table = Table("asc", "monitor_id", null, 100, 100, null) + ) + Assert.assertEquals(res100to200.associatedAlerts.size, 100) + var res200to300 = getWorkflowAlerts( + workflowId = workflowId, + alertIds = listOf(res.alerts[0].id), + table = Table("asc", "monitor_id", null, 100, 201, null) + ) + Assert.assertEquals(res200to300.associatedAlerts.size, 100) + var res0to99 = getWorkflowAlerts( + workflowId = workflowId, + alertIds = listOf(res.alerts[0].id), + table = Table("asc", "monitor_id", null, 100, 0, null) + ) + Assert.assertEquals(res0to99.associatedAlerts.size, 100) + + val ids100to200 = res100to200.associatedAlerts.stream().map { it.id }.collect(Collectors.toSet()) + val idsSet0to99 = res0to99.associatedAlerts.stream().map { it.id }.collect(Collectors.toSet()) + val idsSet200to300 = res200to300.associatedAlerts.stream().map { it.id }.collect(Collectors.toSet()) + + Assert.assertTrue(idsSet0to99.all { it !in ids100to200 }) + Assert.assertTrue(idsSet0to99.all { it !in idsSet200to300 }) + Assert.assertTrue(ids100to200.all { it !in idsSet200to300 }) + } } diff --git a/alerting/src/test/kotlin/org/opensearch/alerting/resthandler/WorkflowRestApiIT.kt b/alerting/src/test/kotlin/org/opensearch/alerting/resthandler/WorkflowRestApiIT.kt index fe751de75..fc1e69569 100644 --- a/alerting/src/test/kotlin/org/opensearch/alerting/resthandler/WorkflowRestApiIT.kt +++ b/alerting/src/test/kotlin/org/opensearch/alerting/resthandler/WorkflowRestApiIT.kt @@ -1101,7 +1101,7 @@ class WorkflowRestApiIT : AlertingRestTestCase() { assertTrue( (workflowTriggerResults[andTrigger.id] as Map)["triggered"] as Boolean ) - val res = getWorkflowAlerts(workflowId, true) + val res = getWorkflowAlerts(workflowId = workflowId, getAssociatedAlerts = true) val getWorkflowAlerts = entityAsMap(res) Assert.assertTrue(getWorkflowAlerts.containsKey("alerts")) Assert.assertTrue(getWorkflowAlerts.containsKey("associatedAlerts")) @@ -1113,6 +1113,18 @@ class WorkflowRestApiIT : AlertingRestTestCase() { val associatedAlerts = getWorkflowAlerts["associatedAlerts"] as List> assertEquals(associatedAlerts.size, 2) + val res1 = getWorkflowAlerts(workflowId = workflowId, alertId = alerts[0]["id"].toString(), getAssociatedAlerts = true) + val getWorkflowAlerts1 = entityAsMap(res1) + Assert.assertTrue(getWorkflowAlerts1.containsKey("alerts")) + Assert.assertTrue(getWorkflowAlerts1.containsKey("associatedAlerts")) + val alerts1 = getWorkflowAlerts1["alerts"] as List> + assertEquals(alerts1.size, 1) + Assert.assertEquals(alerts1[0]["execution_id"], executionId) + Assert.assertEquals(alerts1[0]["workflow_id"], workflowId) + Assert.assertEquals(alerts1[0]["monitor_id"], "") + val associatedAlerts1 = getWorkflowAlerts1["associatedAlerts"] as List> + assertEquals(associatedAlerts1.size, 2) + val getAlertsRes = getAlerts() val getAlertsMap = getAlertsRes.asMap() Assert.assertTrue(getAlertsMap.containsKey("alerts")) @@ -1121,11 +1133,11 @@ class WorkflowRestApiIT : AlertingRestTestCase() { Assert.assertEquals(getAlertsAlerts[0]["execution_id"], executionId) Assert.assertEquals(getAlertsAlerts[0]["workflow_id"], workflowId) Assert.assertEquals(getAlertsAlerts[0]["monitor_id"], "") - Assert.assertEquals(getAlertsAlerts[0]["id"], alerts[0]["id"]) + Assert.assertEquals(getAlertsAlerts[0]["id"], alerts1[0]["id"]) - val ackRes = acknowledgeChainedAlerts(workflowId, alerts[0]["id"].toString()) + val ackRes = acknowledgeChainedAlerts(workflowId, alerts1[0]["id"].toString()) val acknowledgeChainedAlertsResponse = entityAsMap(ackRes) val acknowledged = acknowledgeChainedAlertsResponse["success"] as List - Assert.assertEquals(acknowledged[0], alerts[0]["id"]) + Assert.assertEquals(acknowledged[0], alerts1[0]["id"]) } } diff --git a/alerting/src/test/kotlin/org/opensearch/alerting/transport/AlertingSingleNodeTestCase.kt b/alerting/src/test/kotlin/org/opensearch/alerting/transport/AlertingSingleNodeTestCase.kt index 79c970d3b..46b656b94 100644 --- a/alerting/src/test/kotlin/org/opensearch/alerting/transport/AlertingSingleNodeTestCase.kt +++ b/alerting/src/test/kotlin/org/opensearch/alerting/transport/AlertingSingleNodeTestCase.kt @@ -283,19 +283,20 @@ abstract class AlertingSingleNodeTestCase : OpenSearchSingleNodeTestCase() { alertState: Alert.State? = Alert.State.ACTIVE, alertIndex: String? = "", associatedAlertsIndex: String? = "", + alertIds: List? = emptyList(), + table: Table? = Table("asc", "monitor_id", null, 100, 0, null), ): GetWorkflowAlertsResponse { - val table = Table("asc", "monitor_id", null, 100, 0, null) return client().execute( AlertingActions.GET_WORKFLOW_ALERTS_ACTION_TYPE, GetWorkflowAlertsRequest( - table = table, + table = table!!, severityLevel = "ALL", alertState = alertState!!.name, alertIndex = alertIndex, associatedAlertsIndex = associatedAlertsIndex, monitorIds = emptyList(), workflowIds = listOf(workflowId), - alertIds = emptyList(), + alertIds = alertIds, getAssociatedAlerts = getAssociatedAlerts!! ) ).get()