diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java index 4396b45898..dbd6411998 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java @@ -25,6 +25,7 @@ import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; import org.opensearch.sql.spark.dispatcher.model.JobType; import org.opensearch.sql.spark.leasemanager.LeaseManager; +import org.opensearch.sql.spark.leasemanager.model.LeaseRequest; import org.opensearch.sql.spark.metrics.MetricsService; import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider; import org.opensearch.sql.spark.response.JobExecutionResponseReader; @@ -75,6 +76,14 @@ public String cancelJob( return asyncQueryJobMetadata.getQueryId(); } + /** + * This method allows RefreshQueryHandler to override the job type when calling + * leaseManager.borrow. + */ + protected void borrow(String datasource) { + leaseManager.borrow(new LeaseRequest(JobType.BATCH, datasource)); + } + @Override public DispatchQueryResponse submit( DispatchQueryRequest dispatchQueryRequest, DispatchQueryContext context) { @@ -82,6 +91,8 @@ public DispatchQueryResponse submit( Map tags = context.getTags(); DataSourceMetadata dataSourceMetadata = context.getDataSourceMetadata(); + this.borrow(dispatchQueryRequest.getDatasource()); + tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); StartJobRequest startJobRequest = new StartJobRequest( diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java index 52cd863081..659166e8a6 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/RefreshQueryHandler.java @@ -71,10 +71,14 @@ public String cancelJob( return asyncQueryJobMetadata.getQueryId(); } + @Override + protected void borrow(String datasource) { + leaseManager.borrow(new LeaseRequest(JobType.REFRESH, datasource)); + } + @Override public DispatchQueryResponse submit( DispatchQueryRequest dispatchQueryRequest, DispatchQueryContext context) { - leaseManager.borrow(new LeaseRequest(JobType.REFRESH, dispatchQueryRequest.getDatasource())); DispatchQueryResponse resp = super.submit(dispatchQueryRequest, context); DataSourceMetadata dataSourceMetadata = context.getDataSourceMetadata(); diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java index d4a6b544c4..5ef8343dcc 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java @@ -77,6 +77,7 @@ import org.opensearch.sql.spark.flint.IndexDMLResultStorageService; import org.opensearch.sql.spark.flint.operation.FlintIndexOpFactory; import org.opensearch.sql.spark.leasemanager.LeaseManager; +import org.opensearch.sql.spark.leasemanager.model.LeaseRequest; import org.opensearch.sql.spark.metrics.MetricsService; import org.opensearch.sql.spark.parameter.SparkParameterComposerCollection; import org.opensearch.sql.spark.parameter.SparkSubmitParametersBuilderProvider; @@ -137,6 +138,7 @@ public class AsyncQueryCoreIntegTest { @Captor ArgumentCaptor flintIndexOptionsArgumentCaptor; @Captor ArgumentCaptor startJobRunRequestArgumentCaptor; @Captor ArgumentCaptor createSessionRequestArgumentCaptor; + @Captor ArgumentCaptor leaseRequestArgumentCaptor; AsyncQueryExecutorService asyncQueryExecutorService; @@ -267,7 +269,8 @@ public void createVacuumIndexQuery() { assertEquals(SESSION_ID, response.getSessionId()); verifyGetQueryIdCalled(); verifyGetSessionIdCalled(); - verify(leaseManager).borrow(any()); + verify(leaseManager).borrow(leaseRequestArgumentCaptor.capture()); + assertEquals(JobType.INTERACTIVE, leaseRequestArgumentCaptor.getValue().getJobType()); verifyStartJobRunCalled(); verifyStoreJobMetadataCalled(JOB_ID, QueryState.WAITING, JobType.INTERACTIVE); } @@ -356,11 +359,38 @@ public void createStreamingQuery() { assertEquals(QUERY_ID, response.getQueryId()); assertNull(response.getSessionId()); verifyGetQueryIdCalled(); - verify(leaseManager).borrow(any()); + verify(leaseManager).borrow(leaseRequestArgumentCaptor.capture()); + assertEquals(JobType.STREAMING, leaseRequestArgumentCaptor.getValue().getJobType()); verifyStartJobRunCalled(); verifyStoreJobMetadataCalled(JOB_ID, QueryState.WAITING, JobType.STREAMING); } + @Test + public void createBatchQuery() { + givenSparkExecutionEngineConfigIsSupplied(); + givenValidDataSourceMetadataExist(); + when(queryIdProvider.getQueryId(any(), eq(asyncQueryRequestContext))).thenReturn(QUERY_ID); + when(awsemrServerless.startJobRun(any())) + .thenReturn(new StartJobRunResult().withApplicationId(APPLICATION_ID).withJobRunId(JOB_ID)); + + CreateAsyncQueryResponse response = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest( + "CREATE INDEX index_name ON table_name(l_orderkey, l_quantity)" + + " WITH (auto_refresh = false)", + DATASOURCE_NAME, + LangType.SQL), + asyncQueryRequestContext); + + assertEquals(QUERY_ID, response.getQueryId()); + assertNull(response.getSessionId()); + verifyGetQueryIdCalled(); + verify(leaseManager).borrow(leaseRequestArgumentCaptor.capture()); + assertEquals(JobType.BATCH, leaseRequestArgumentCaptor.getValue().getJobType()); + verifyStartJobRunCalled(); + verifyStoreJobMetadataCalled(JOB_ID, QueryState.WAITING, JobType.BATCH); + } + private void verifyStartJobRunCalled() { verify(awsemrServerless).startJobRun(startJobRunRequestArgumentCaptor.capture()); StartJobRunRequest startJobRunRequest = startJobRunRequestArgumentCaptor.getValue(); @@ -413,7 +443,8 @@ public void createRefreshQuery() { assertEquals(QUERY_ID, response.getQueryId()); assertNull(response.getSessionId()); verifyGetQueryIdCalled(); - verify(leaseManager).borrow(any()); + verify(leaseManager).borrow(leaseRequestArgumentCaptor.capture()); + assertEquals(JobType.REFRESH, leaseRequestArgumentCaptor.getValue().getJobType()); verifyStartJobRunCalled(); verifyStoreJobMetadataCalled(JOB_ID, QueryState.WAITING, JobType.REFRESH); } @@ -439,7 +470,8 @@ public void createInteractiveQuery() { assertEquals(SESSION_ID, response.getSessionId()); verifyGetQueryIdCalled(); verifyGetSessionIdCalled(); - verify(leaseManager).borrow(any()); + verify(leaseManager).borrow(leaseRequestArgumentCaptor.capture()); + assertEquals(JobType.INTERACTIVE, leaseRequestArgumentCaptor.getValue().getJobType()); verifyStartJobRunCalled(); verifyStoreJobMetadataCalled(JOB_ID, QueryState.WAITING, JobType.INTERACTIVE); } diff --git a/async-query/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java b/async-query/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java index 8d57198277..18dce7b7b2 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java @@ -237,7 +237,8 @@ private void createIndex(String indexName) { } } - private long count(String indexName, QueryBuilder query) { + @VisibleForTesting + public long count(String indexName, QueryBuilder query) { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchSourceBuilder.query(query); searchSourceBuilder.size(0); diff --git a/async-query/src/main/java/org/opensearch/sql/spark/leasemanager/DefaultLeaseManager.java b/async-query/src/main/java/org/opensearch/sql/spark/leasemanager/DefaultLeaseManager.java index 375fa7b11e..db8ca1ad2b 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/leasemanager/DefaultLeaseManager.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/leasemanager/DefaultLeaseManager.java @@ -92,7 +92,8 @@ public String description() { @Override public boolean test(LeaseRequest leaseRequest) { - if (leaseRequest.getJobType() == JobType.INTERACTIVE) { + if (leaseRequest.getJobType() != JobType.REFRESH + && leaseRequest.getJobType() != JobType.STREAMING) { return true; } return activeRefreshJobCount(stateStore, ALL_DATASOURCE).get() < refreshJobLimit(); diff --git a/async-query/src/test/java/org/opensearch/sql/spark/leasemanager/DefaultLeaseManagerTest.java b/async-query/src/test/java/org/opensearch/sql/spark/leasemanager/DefaultLeaseManagerTest.java index 558f7f7b3a..a7ea6aa22f 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/leasemanager/DefaultLeaseManagerTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/leasemanager/DefaultLeaseManagerTest.java @@ -5,7 +5,9 @@ package org.opensearch.sql.spark.leasemanager; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -23,19 +25,36 @@ class DefaultLeaseManagerTest { @Mock private StateStore stateStore; @Test - public void concurrentSessionRuleOnlyApplyToInteractiveQuery() { - assertTrue( - new DefaultLeaseManager.ConcurrentSessionRule(settings, stateStore) - .test(new LeaseRequest(JobType.BATCH, "mys3"))); - assertTrue( - new DefaultLeaseManager.ConcurrentSessionRule(settings, stateStore) - .test(new LeaseRequest(JobType.STREAMING, "mys3"))); + public void leaseManagerRejectsJobs() { + when(stateStore.count(any(), any())).thenReturn(3L); + when(settings.getSettingValue(any())).thenReturn(3); + DefaultLeaseManager defaultLeaseManager = new DefaultLeaseManager(settings, stateStore); + + defaultLeaseManager.borrow(getLeaseRequest(JobType.BATCH)); + assertThrows( + ConcurrencyLimitExceededException.class, + () -> defaultLeaseManager.borrow(getLeaseRequest(JobType.INTERACTIVE))); + assertThrows( + ConcurrencyLimitExceededException.class, + () -> defaultLeaseManager.borrow(getLeaseRequest(JobType.STREAMING))); + assertThrows( + ConcurrencyLimitExceededException.class, + () -> defaultLeaseManager.borrow(getLeaseRequest(JobType.REFRESH))); } @Test - public void concurrentRefreshRuleOnlyNotAppliedToInteractiveQuery() { - assertTrue( - new DefaultLeaseManager.ConcurrentRefreshJobRule(settings, stateStore) - .test(new LeaseRequest(JobType.INTERACTIVE, "mys3"))); + public void leaseManagerAcceptsJobs() { + when(stateStore.count(any(), any())).thenReturn(2L); + when(settings.getSettingValue(any())).thenReturn(3); + DefaultLeaseManager defaultLeaseManager = new DefaultLeaseManager(settings, stateStore); + + defaultLeaseManager.borrow(getLeaseRequest(JobType.BATCH)); + defaultLeaseManager.borrow(getLeaseRequest(JobType.INTERACTIVE)); + defaultLeaseManager.borrow(getLeaseRequest(JobType.STREAMING)); + defaultLeaseManager.borrow(getLeaseRequest(JobType.REFRESH)); + } + + private LeaseRequest getLeaseRequest(JobType jobType) { + return new LeaseRequest(jobType, "mys3"); } }