Call updateState when query is cancelled (#3139)
* Call updateState when query is cancelled

Signed-off-by: Tomoyuki Morita <[email protected]>

* Fix code style

Signed-off-by: Tomoyuki Morita <[email protected]>

---------

Signed-off-by: Tomoyuki Morita <[email protected]>
ykmr1224 authored Oct 30, 2024
1 parent 3b86612 commit 5716cab
Showing 4 changed files with 34 additions and 13 deletions.
@@ -19,6 +19,7 @@
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.asyncquery.model.QueryState;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;
import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher;
@@ -116,7 +117,11 @@ public String cancelQuery(String queryId, AsyncQueryRequestContext asyncQueryReq
Optional<AsyncQueryJobMetadata> asyncQueryJobMetadata =
asyncQueryJobMetadataStorageService.getJobMetadata(queryId);
if (asyncQueryJobMetadata.isPresent()) {
return sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata.get(), asyncQueryRequestContext);
String result =
sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata.get(), asyncQueryRequestContext);
asyncQueryJobMetadataStorageService.updateState(
asyncQueryJobMetadata.get(), QueryState.CANCELLED, asyncQueryRequestContext);
return result;
}
throw new AsyncQueryNotFoundException(String.format("QueryId: %s not found", queryId));
}
@@ -10,12 +10,18 @@
import java.util.Optional;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.asyncquery.model.QueryState;

public interface AsyncQueryJobMetadataStorageService {

void storeJobMetadata(
AsyncQueryJobMetadata asyncQueryJobMetadata,
AsyncQueryRequestContext asyncQueryRequestContext);

void updateState(
AsyncQueryJobMetadata asyncQueryJobMetadata,
QueryState newState,
AsyncQueryRequestContext asyncQueryRequestContext);

Optional<AsyncQueryJobMetadata> getJobMetadata(String jobId);
}
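
The new updateState hook is easiest to see in isolation. The sketch below is an illustrative in-memory implementation of the extended interface, for example for tests; it is not part of this commit, the class name is made up, and it assumes AsyncQueryJobMetadata exposes a getQueryId() accessor.

package org.opensearch.sql.spark.asyncquery;

import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.asyncquery.model.QueryState;

// Hypothetical in-memory store; keys entries by query id, matching how
// AsyncQueryExecutorServiceImpl looks metadata up in cancelQuery.
public class InMemoryAsyncQueryJobMetadataStorageService
    implements AsyncQueryJobMetadataStorageService {

  private final Map<String, AsyncQueryJobMetadata> metadataById = new ConcurrentHashMap<>();
  private final Map<String, QueryState> stateById = new ConcurrentHashMap<>();

  @Override
  public void storeJobMetadata(
      AsyncQueryJobMetadata asyncQueryJobMetadata,
      AsyncQueryRequestContext asyncQueryRequestContext) {
    metadataById.put(asyncQueryJobMetadata.getQueryId(), asyncQueryJobMetadata);
  }

  @Override
  public void updateState(
      AsyncQueryJobMetadata asyncQueryJobMetadata,
      QueryState newState,
      AsyncQueryRequestContext asyncQueryRequestContext) {
    // Track the latest state separately, since the metadata record itself has no state field.
    stateById.put(asyncQueryJobMetadata.getQueryId(), newState);
  }

  @Override
  public Optional<AsyncQueryJobMetadata> getJobMetadata(String jobId) {
    return Optional.ofNullable(metadataById.get(jobId));
  }
}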
@@ -5,6 +5,7 @@

package org.opensearch.sql.spark.asyncquery;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
@@ -33,6 +34,7 @@
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.asyncquery.model.QueryState;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;
import org.opensearch.sql.spark.config.SparkSubmitParameterModifier;
@@ -109,7 +111,7 @@ void testCreateAsyncQuery() {
.getSparkExecutionEngineConfig(asyncQueryRequestContext);
verify(sparkQueryDispatcher, times(1))
.dispatch(expectedDispatchQueryRequest, asyncQueryRequestContext);
Assertions.assertEquals(QUERY_ID, createAsyncQueryResponse.getQueryId());
assertEquals(QUERY_ID, createAsyncQueryResponse.getQueryId());
}

@Test
@@ -153,8 +155,7 @@ void testGetAsyncQueryResultsWithJobNotFoundException() {
AsyncQueryNotFoundException.class,
() -> jobExecutorService.getAsyncQueryResults(EMR_JOB_ID, asyncQueryRequestContext));

Assertions.assertEquals(
"QueryId: " + EMR_JOB_ID + " not found", asyncQueryNotFoundException.getMessage());
assertEquals("QueryId: " + EMR_JOB_ID + " not found", asyncQueryNotFoundException.getMessage());
verifyNoInteractions(sparkQueryDispatcher);
verifyNoInteractions(sparkExecutionEngineConfigSupplier);
}
@@ -174,7 +175,7 @@ void testGetAsyncQueryResultsWithInProgressJob() {

Assertions.assertNull(asyncQueryExecutionResponse.getResults());
Assertions.assertNull(asyncQueryExecutionResponse.getSchema());
Assertions.assertEquals("PENDING", asyncQueryExecutionResponse.getStatus());
assertEquals("PENDING", asyncQueryExecutionResponse.getStatus());
verifyNoInteractions(sparkExecutionEngineConfigSupplier);
}

@@ -191,11 +192,10 @@ void testGetAsyncQueryResultsWithSuccessJob() throws IOException {
AsyncQueryExecutionResponse asyncQueryExecutionResponse =
jobExecutorService.getAsyncQueryResults(EMR_JOB_ID, asyncQueryRequestContext);

Assertions.assertEquals("SUCCESS", asyncQueryExecutionResponse.getStatus());
Assertions.assertEquals(1, asyncQueryExecutionResponse.getSchema().getColumns().size());
Assertions.assertEquals(
"1", asyncQueryExecutionResponse.getSchema().getColumns().get(0).getName());
Assertions.assertEquals(
assertEquals("SUCCESS", asyncQueryExecutionResponse.getStatus());
assertEquals(1, asyncQueryExecutionResponse.getSchema().getColumns().size());
assertEquals("1", asyncQueryExecutionResponse.getSchema().getColumns().get(0).getName());
assertEquals(
1,
((HashMap<String, String>) asyncQueryExecutionResponse.getResults().get(0).value())
.get("1"));
@@ -212,8 +212,7 @@ void testCancelJobWithJobNotFound() {
AsyncQueryNotFoundException.class,
() -> jobExecutorService.cancelQuery(EMR_JOB_ID, asyncQueryRequestContext));

Assertions.assertEquals(
"QueryId: " + EMR_JOB_ID + " not found", asyncQueryNotFoundException.getMessage());
assertEquals("QueryId: " + EMR_JOB_ID + " not found", asyncQueryNotFoundException.getMessage());
verifyNoInteractions(sparkQueryDispatcher);
verifyNoInteractions(sparkExecutionEngineConfigSupplier);
}
@@ -227,7 +226,9 @@ void testCancelJob() {

String jobId = jobExecutorService.cancelQuery(EMR_JOB_ID, asyncQueryRequestContext);

Assertions.assertEquals(EMR_JOB_ID, jobId);
assertEquals(EMR_JOB_ID, jobId);
verify(asyncQueryJobMetadataStorageService)
.updateState(any(), eq(QueryState.CANCELLED), eq(asyncQueryRequestContext));
verifyNoInteractions(sparkExecutionEngineConfigSupplier);
}

@@ -12,6 +12,7 @@
import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryRequestContext;
import org.opensearch.sql.spark.asyncquery.model.QueryState;
import org.opensearch.sql.spark.execution.statestore.OpenSearchStateStoreUtil;
import org.opensearch.sql.spark.execution.statestore.StateStore;
import org.opensearch.sql.spark.execution.xcontent.AsyncQueryJobMetadataXContentSerializer;
@@ -39,6 +40,14 @@ public void storeJobMetadata(
OpenSearchStateStoreUtil.getIndexName(asyncQueryJobMetadata.getDatasourceName()));
}

@Override
public void updateState(
AsyncQueryJobMetadata asyncQueryJobMetadata,
QueryState newState,
AsyncQueryRequestContext asyncQueryRequestContext) {
// NoOp since AsyncQueryJobMetadata record does not store state now
}

private String mapIdToDocumentId(String id) {
return "qid" + id;
}
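
The OpenSearch-backed storage service leaves updateState empty because the persisted metadata document carries no state field yet, while the executor service already reports the CANCELLED transition through the storage abstraction. Purely to illustrate where a real implementation would go, a state-persisting variant might look like the sketch below; nothing in it exists in this commit, and the toBuilder()/state(...) methods, the persistJobMetadata helper, and the getQueryId() accessor are all invented for the example.

  @Override
  public void updateState(
      AsyncQueryJobMetadata asyncQueryJobMetadata,
      QueryState newState,
      AsyncQueryRequestContext asyncQueryRequestContext) {
    // Hypothetical: assumes AsyncQueryJobMetadata gains a state field plus a toBuilder()
    // copy method, and that persistJobMetadata(...) re-indexes the existing document.
    AsyncQueryJobMetadata updated =
        asyncQueryJobMetadata.toBuilder().state(newState).build();
    persistJobMetadata(
        updated,
        mapIdToDocumentId(updated.getQueryId()),
        OpenSearchStateStoreUtil.getIndexName(updated.getDatasourceName()));
  }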
