From 27eed03359b59cfccd2a72a11b33e0e54283c247 Mon Sep 17 00:00:00 2001 From: kkewwei Date: Tue, 19 Nov 2024 16:58:14 +0800 Subject: [PATCH] Coordinator can return partial results after the timeout when allow_partial_search_results is true Signed-off-by: kkewwei --- CHANGELOG.md | 1 + .../action/search/SearchRequest.java | 34 ++++++--- .../opensearch/action/search/SearchTask.java | 1 + .../action/search/SearchTransportService.java | 8 +-- .../rest/action/search/RestSearchAction.java | 4 +- .../org/opensearch/search/SearchService.java | 5 -- .../opensearch/search/query/QueryPhase.java | 6 -- .../AbstractSearchAsyncActionTests.java | 69 ++++++++++++++++++- .../action/search/SearchRequestTests.java | 22 ++++++ .../search/SearchTransportServiceTests.java | 36 ++++++++++ 10 files changed, 158 insertions(+), 28 deletions(-) create mode 100644 server/src/test/java/org/opensearch/action/search/SearchTransportServiceTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 1c07a9362e3e7..b974c5530894c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add vertical scaling and SoftReference for snapshot repository data cache ([#16489](https://github.com/opensearch-project/OpenSearch/pull/16489)) - Support prefix list for remote repository attributes([#16271](https://github.com/opensearch-project/OpenSearch/pull/16271)) - Add new configuration setting `synonym_analyzer`, to the `synonym` and `synonym_graph` filters, enabling the specification of a custom analyzer for reading the synonym file ([#16488](https://github.com/opensearch-project/OpenSearch/pull/16488)). +- Coordinator can return partial results after the timeout when `allow_partial_search_results` is true ([#16681](https://github.com/opensearch-project/OpenSearch/pull/16681)). 
### Dependencies - Bump `com.google.cloud:google-cloud-core-http` from 2.23.0 to 2.47.0 ([#16504](https://github.com/opensearch-project/OpenSearch/pull/16504)) diff --git a/server/src/main/java/org/opensearch/action/search/SearchRequest.java b/server/src/main/java/org/opensearch/action/search/SearchRequest.java index c95f71202be4f..8fd67d4488851 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchRequest.java +++ b/server/src/main/java/org/opensearch/action/search/SearchRequest.java @@ -128,7 +128,7 @@ public class SearchRequest extends ActionRequest implements IndicesRequest.Repla // it's only been used in coordinator, so we don't need to serialize/deserialize it private long startTimeMills; - private float queryPhaseTimeoutPercentage; + private float queryPhaseTimeoutPercentage = 0.8f; public SearchRequest() { this.localClusterAlias = null; @@ -358,6 +358,10 @@ public ActionRequestValidationException validate() { validationException = addValidationError("using [point in time] is not allowed in a scroll context", validationException); } } + + if (queryPhaseTimeoutPercentage <= 0 || queryPhaseTimeoutPercentage >= 1) { + validationException = addValidationError("[queryPhaseTimeoutPercentage] must be in (0, 1)", validationException); + } return validationException; } @@ -722,21 +726,27 @@ public String pipeline() { return pipeline; } - public void setQueryPhaseTimeoutPercentage(float queryPhaseTimeoutPercentage) { if (source.timeout() == null) { - throw new IllegalArgumentException("timeout must be set before setting query phase timeout percentage"); + throw new IllegalArgumentException("timeout must be set before setting queryPhaseTimeoutPercentage"); } this.queryPhaseTimeoutPercentage = queryPhaseTimeoutPercentage; } - public float getQueryPhasePercentage() { - return queryPhaseTimeoutPercentage; - } - @Override public SearchTask createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { - return new SearchTask(id, type, 
action, this::buildDescription, parentTaskId, headers, cancelAfterTimeInterval, startTimeMills, source.timeout() != null? source.timeout().millis() : -1, queryPhaseTimeoutPercentage); + return new SearchTask( + id, + type, + action, + this::buildDescription, + parentTaskId, + headers, + cancelAfterTimeInterval, + startTimeMills, + (source != null && source.timeout() != null) ? source.timeout().millis() : -1, + queryPhaseTimeoutPercentage + ); } public final String buildDescription() { @@ -788,7 +798,8 @@ public boolean equals(Object o) { && ccsMinimizeRoundtrips == that.ccsMinimizeRoundtrips && Objects.equals(cancelAfterTimeInterval, that.cancelAfterTimeInterval) && Objects.equals(pipeline, that.pipeline) - && Objects.equals(phaseTook, that.phaseTook); + && Objects.equals(phaseTook, that.phaseTook) + && Objects.equals(queryPhaseTimeoutPercentage, that.queryPhaseTimeoutPercentage); } @Override @@ -810,7 +821,8 @@ public int hashCode() { absoluteStartMillis, ccsMinimizeRoundtrips, cancelAfterTimeInterval, - phaseTook + phaseTook, + queryPhaseTimeoutPercentage ); } @@ -855,6 +867,8 @@ public String toString() { + pipeline + ", phaseTook=" + phaseTook + + ", queryPhaseTimeoutPercentage=" + + queryPhaseTimeoutPercentage + "}"; } } diff --git a/server/src/main/java/org/opensearch/action/search/SearchTask.java b/server/src/main/java/org/opensearch/action/search/SearchTask.java index d59f9bfd9252a..7460095b41d26 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchTask.java +++ b/server/src/main/java/org/opensearch/action/search/SearchTask.java @@ -84,6 +84,7 @@ public SearchTask( this.descriptionSupplier = descriptionSupplier; this.startTimeMills = startTimeMills; this.timeoutMills = timeoutMills; + assert queryPhaseTimeoutPercentage > 0 && queryPhaseTimeoutPercentage <= 1; this.queryPhaseTimeoutPercentage = queryPhaseTimeoutPercentage; } diff --git a/server/src/main/java/org/opensearch/action/search/SearchTransportService.java 
b/server/src/main/java/org/opensearch/action/search/SearchTransportService.java index dbe8fa253c279..0ba8cdd5d3b94 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchTransportService.java +++ b/server/src/main/java/org/opensearch/action/search/SearchTransportService.java @@ -400,17 +400,17 @@ void sendExecuteMultiSearch(final MultiSearchRequest request, SearchTask task, f ); } - public TransportRequestOptions getTransportRequestOptions(SearchTask task, Consumer onFailure, boolean queryPhase) { - if (task.timeoutMills() > 0) { + static TransportRequestOptions getTransportRequestOptions(SearchTask task, Consumer onFailure, boolean queryPhase) { + if (task != null && task.timeoutMills() > 0) { long leftTimeMills; if (queryPhase) { - //it's costly in query phase. + // it's costly in query phase. leftTimeMills = task.queryPhaseTimeout() - (System.currentTimeMillis() - task.startTimeMills()); } else { leftTimeMills = task.timeoutMills() - (System.currentTimeMillis() - task.startTimeMills()); } if (leftTimeMills <= 0) { - onFailure.accept(new TaskCancelledException("failed to execute fetch phase, timeout exceeded")); + onFailure.accept(new TaskCancelledException("failed to execute fetch phase, timeout exceeded" + leftTimeMills + "ms")); return null; } else { return TransportRequestOptions.builder().withTimeout(leftTimeMills).build(); diff --git a/server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java b/server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java index 33a69ad489cc9..1d00a9d8dca47 100644 --- a/server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java +++ b/server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java @@ -225,7 +225,9 @@ public static void parseSearchRequest( searchRequest.setCancelAfterTimeInterval(request.paramAsTime("cancel_after_time_interval", null)); - 
searchRequest.setQueryPhaseTimeoutPercentage(request.paramAsFloat("query_phase_timeout_percentage", SearchRequest.DEFAULT_QUERY_PHASE_TIMEOUT_PERCENTAGE)); + searchRequest.setQueryPhaseTimeoutPercentage( + request.paramAsFloat("query_phase_timeout_percentage", SearchRequest.DEFAULT_QUERY_PHASE_TIMEOUT_PERCENTAGE) + ); } /** diff --git a/server/src/main/java/org/opensearch/search/SearchService.java b/server/src/main/java/org/opensearch/search/SearchService.java index 03f80142b6d21..e892a2f1a7620 100644 --- a/server/src/main/java/org/opensearch/search/SearchService.java +++ b/server/src/main/java/org/opensearch/search/SearchService.java @@ -876,11 +876,6 @@ public void executeFetchPhase(ShardFetchRequest request, SearchShardTask task, A final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.getShardSearchRequest()); final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(shardSearchRequest)); runAsync(getExecutor(readerContext.indexShard()), () -> { - if (request.getShardSearchRequest().shardId().getId() == 1) { - try { - Thread.sleep(10000); - } catch (Exception e) {} - } try (SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, false)) { if (request.lastEmittedDoc() != null) { searchContext.scrollContext().lastEmittedDoc = request.lastEmittedDoc(); diff --git a/server/src/main/java/org/opensearch/search/query/QueryPhase.java b/server/src/main/java/org/opensearch/search/query/QueryPhase.java index 891f6f343d741..55b7c0bc5178d 100644 --- a/server/src/main/java/org/opensearch/search/query/QueryPhase.java +++ b/server/src/main/java/org/opensearch/search/query/QueryPhase.java @@ -143,12 +143,6 @@ public void execute(SearchContext searchContext) throws QueryPhaseExecutionExcep return; } - if (searchContext.request().shardId().getId() == 2) { - try { - Thread.sleep(10000); - } catch (Exception e) {} - } - if (LOGGER.isTraceEnabled()) { LOGGER.trace("{}", new 
SearchContextSourcePrinter(searchContext)); } diff --git a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java index 27336e86e52b0..b4b4e08709538 100644 --- a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java @@ -36,11 +36,13 @@ import org.opensearch.action.OriginalIndices; import org.opensearch.action.support.IndicesOptions; import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.routing.GroupShardsIterator; import org.opensearch.common.UUIDs; import org.opensearch.common.collect.Tuple; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.AtomicArray; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.set.Sets; @@ -55,6 +57,7 @@ import org.opensearch.index.shard.ShardNotFoundException; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.internal.ShardSearchContextId; @@ -65,6 +68,7 @@ import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.ReceiveTimeoutTransportException; import org.opensearch.transport.Transport; import org.junit.After; import org.junit.Before; @@ -89,6 +93,9 @@ import java.util.function.BiFunction; import java.util.stream.IntStream; +import org.mockito.Mockito; + +import static 
org.opensearch.action.search.SearchTransportService.QUERY_ACTION_NAME; import static org.opensearch.tasks.TaskResourceTrackingService.TASK_RESOURCE_USAGE; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; @@ -138,6 +145,7 @@ private AbstractSearchAsyncAction createAction( false, expected, resourceUsage, + false, new SearchShardIterator(null, null, Collections.emptyList(), null) ); } @@ -151,6 +159,7 @@ private AbstractSearchAsyncAction createAction( final boolean catchExceptionWhenExecutePhaseOnShard, final AtomicLong expected, final TaskResourceUsage resourceUsage, + final boolean blockTheFirstQueryPhase, final SearchShardIterator... shards ) { @@ -179,7 +188,7 @@ private AbstractSearchAsyncAction createAction( .setNodeId(randomAlphaOfLengthBetween(1, 5)) .build(); threadPool.getThreadContext().addResponseHeader(TASK_RESOURCE_USAGE, taskResourceInfo.toString()); - + AtomicBoolean firstShard = new AtomicBoolean(true); return new AbstractSearchAsyncAction( "test", logger, @@ -207,7 +216,13 @@ private AbstractSearchAsyncAction createAction( ) { @Override protected SearchPhase getNextPhase(final SearchPhaseResults results, SearchPhaseContext context) { - return null; + return new SearchPhase("test") { + @Override + public void run() { + listener.onResponse(new SearchResponse(null, null, 0, 0, 0, 0, null, null)); + assertingListener.onPhaseEnd(context, null); + } + }; } @Override @@ -218,6 +233,17 @@ protected void executePhaseOnShard( ) { if (failExecutePhaseOnShard) { listener.onFailure(new ShardNotFoundException(shardIt.shardId())); + } else if (blockTheFirstQueryPhase && firstShard.compareAndSet(true, false)) { + // Sleep and throw ReceiveTimeoutTransportException to simulate node blocked + try { + Thread.sleep(request.source().timeout().millis()); + } catch (InterruptedException e) {} + ; + DiscoveryNode node = Mockito.mock(DiscoveryNode.class); + Mockito.when(node.getName()).thenReturn("test_nodes"); + 
listener.onFailure( + new ReceiveTimeoutTransportException(node, QUERY_ACTION_NAME, "request_id [171] timed out after [413ms]") + ); } else { if (catchExceptionWhenExecutePhaseOnShard) { try { @@ -227,6 +253,7 @@ protected void executePhaseOnShard( } } else { listener.onResponse(new QuerySearchResult()); + } } } @@ -587,6 +614,7 @@ public void onFailure(Exception e) { false, new AtomicLong(), new TaskResourceUsage(randomLong(), randomLong()), + false, shards ); action.run(); @@ -635,6 +663,7 @@ public void onFailure(Exception e) { false, new AtomicLong(), new TaskResourceUsage(randomLong(), randomLong()), + false, shards ); action.run(); @@ -688,6 +717,7 @@ public void onFailure(Exception e) { catchExceptionWhenExecutePhaseOnShard, new AtomicLong(), new TaskResourceUsage(randomLong(), randomLong()), + false, shards ); action.run(); @@ -791,6 +821,41 @@ public void testOnPhaseListenersWithDfsType() throws InterruptedException { assertEquals(0, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName())); } + public void testExecutePhaseOnShardBlockAndRetrunPartialResult() { + // on shard is blocked in query phase + final Index index = new Index("test", UUID.randomUUID().toString()); + + final SearchShardIterator[] shards = IntStream.range(0, 2 + randomInt(4)) + .mapToObj(i -> new SearchShardIterator(null, new ShardId(index, i), List.of("n1"), null, null, null)) + .toArray(SearchShardIterator[]::new); + + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); + searchRequest.source(new SearchSourceBuilder()); + long timeoutMills = 500; + searchRequest.source().timeout(new TimeValue(timeoutMills, TimeUnit.MILLISECONDS)); + searchRequest.setMaxConcurrentShardRequests(shards.length); + final AtomicBoolean successed = new AtomicBoolean(false); + long current = System.currentTimeMillis(); + + final ArraySearchPhaseResults queryResult = new ArraySearchPhaseResults<>(shards.length); + AbstractSearchAsyncAction action = 
createAction(searchRequest, queryResult, new ActionListener<>() { + @Override + public void onResponse(SearchResponse response) { + successed.set(true); + } + + @Override + public void onFailure(Exception e) { + successed.set(false); + } + }, false, false, false, new AtomicLong(), new TaskResourceUsage(randomLong(), randomLong()), true, shards); + action.run(); + long s = System.currentTimeMillis() - current; + assertTrue(s > timeoutMills); + assertTrue(successed.get()); + + } + private SearchDfsQueryThenFetchAsyncAction createSearchDfsQueryThenFetchAsyncAction( List searchRequestOperationsListeners ) { diff --git a/server/src/test/java/org/opensearch/action/search/SearchRequestTests.java b/server/src/test/java/org/opensearch/action/search/SearchRequestTests.java index acda1445bacbb..64b99ab4af7a5 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchRequestTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchRequestTests.java @@ -238,6 +238,15 @@ public void testValidate() throws IOException { assertEquals(1, validationErrors.validationErrors().size()); assertEquals("using [point in time] is not allowed in a scroll context", validationErrors.validationErrors().get(0)); } + + { + // queryPhaseTimeoutPercentage must be in (0, 1) + SearchRequest searchRequest = createSearchRequest().source(new SearchSourceBuilder().timeout(TimeValue.timeValueMillis(10))); + searchRequest.setQueryPhaseTimeoutPercentage(-1); + ActionRequestValidationException validationErrors = searchRequest.validate(); + assertNotNull(validationErrors); + assertEquals("[queryPhaseTimeoutPercentage] must be in (0, 1)", validationErrors.validationErrors().get(0)); + } } public void testCopyConstructor() throws IOException { @@ -261,6 +270,19 @@ public void testParseSearchRequestWithUnsupportedSearchType() throws IOException assertEquals("Unsupported search type [query_and_fetch]", exception.getMessage()); } + public void 
testParseSearchRequestWithTimeoutAndQueryPhaseTimeoutPercentage() throws IOException { + RestRequest restRequest = new FakeRestRequest(); + SearchRequest searchRequest = createSearchRequest().source(new SearchSourceBuilder()); + IntConsumer setSize = mock(IntConsumer.class); + restRequest.params().put("query_phase_timeout_percentage", "30"); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> RestSearchAction.parseSearchRequest(searchRequest, restRequest, null, namedWriteableRegistry, setSize) + ); + assertEquals("timeout must be set before setting queryPhaseTimeoutPercentage", exception.getMessage()); + } + public void testEqualsAndHashcode() throws IOException { checkEqualsAndHashCode(createSearchRequest(), SearchRequest::new, this::mutate); } diff --git a/server/src/test/java/org/opensearch/action/search/SearchTransportServiceTests.java b/server/src/test/java/org/opensearch/action/search/SearchTransportServiceTests.java new file mode 100644 index 0000000000000..e389ff4b50e5c --- /dev/null +++ b/server/src/test/java/org/opensearch/action/search/SearchTransportServiceTests.java @@ -0,0 +1,36 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.action.search; + +import org.opensearch.core.tasks.TaskCancelledException; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.transport.TransportRequestOptions; + +public class SearchTransportServiceTests extends OpenSearchTestCase { + public void testGetTransportRequestOptions() { + SearchTask searchTask = new SearchTask(1, null, null, null, null, null, null, System.currentTimeMillis(), 1000, 0.8f); + TransportRequestOptions transportRequestOptions = SearchTransportService.getTransportRequestOptions(searchTask, e -> {}, true); + assertTrue(transportRequestOptions.timeout().millis() > 0); + + TransportRequestOptions transportRequestOptions1 = SearchTransportService.getTransportRequestOptions(searchTask, e -> {}, false); + assertTrue(transportRequestOptions.timeout().millis() < transportRequestOptions1.timeout().millis()); + + SearchTask searchTask1 = new SearchTask(1, null, null, null, null, null, null, System.currentTimeMillis(), 1, 0.8f); + + transportRequestOptions = SearchTransportService.getTransportRequestOptions(searchTask1, exception -> { + assertEquals(TaskCancelledException.class, exception.getClass()); + assertTrue(exception.getMessage().contains("failed to execute fetch phase, timeout exceeded")); + }, true); + assertNull(transportRequestOptions); + + searchTask = new SearchTask(1, null, null, null, null, null, null, System.currentTimeMillis(), 0, 0.8f); + transportRequestOptions = SearchTransportService.getTransportRequestOptions(searchTask, e -> {}, false); + assertEquals(TransportRequestOptions.EMPTY, transportRequestOptions); + } +}