diff --git a/benchmarks/src/main/java/org/opensearch/benchmark/search/aggregations/TermsReduceBenchmark.java b/benchmarks/src/main/java/org/opensearch/benchmark/search/aggregations/TermsReduceBenchmark.java index c1b5236801e2d..5073858848e05 100644 --- a/benchmarks/src/main/java/org/opensearch/benchmark/search/aggregations/TermsReduceBenchmark.java +++ b/benchmarks/src/main/java/org/opensearch/benchmark/search/aggregations/TermsReduceBenchmark.java @@ -117,7 +117,8 @@ public InternalAggregation.ReduceContext forFinalReduction() { PipelineAggregator.PipelineTree.EMPTY ); } - }); + } + ); @State(Scope.Benchmark) public static class TermsList extends AbstractList { diff --git a/libs/arrow/src/main/java/org/opensearch/arrow/StreamIterator.java b/libs/arrow/src/main/java/org/opensearch/arrow/StreamIterator.java index a3c4a7e8cedc1..ecb118184f42d 100644 --- a/libs/arrow/src/main/java/org/opensearch/arrow/StreamIterator.java +++ b/libs/arrow/src/main/java/org/opensearch/arrow/StreamIterator.java @@ -34,4 +34,3 @@ public interface StreamIterator extends Closeable { */ VectorSchemaRoot getRoot(); } - diff --git a/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFrame.java b/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFrame.java index b2cb2b7334081..6285ae06c4995 100644 --- a/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFrame.java +++ b/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFrame.java @@ -51,16 +51,13 @@ public CompletableFuture collect(BufferAllocator allocator) { public CompletableFuture getStream(BufferAllocator allocator) { CompletableFuture result = new CompletableFuture<>(); long runtimePointer = ctx.getRuntime(); - DataFusion.executeStream( - runtimePointer, - ptr, - (String errString, long streamId) -> { - if (errString != null && errString.isEmpty() == false) { - result.completeExceptionally(new RuntimeException(errString)); - } else { - result.complete(new RecordBatchStream(ctx, streamId, allocator)); - } - }); + DataFusion.executeStream(runtimePointer, ptr, (String errString, long streamId) -> { + if (errString != null && errString.isEmpty() == false) { + result.completeExceptionally(new RuntimeException(errString)); + } else { + result.complete(new RecordBatchStream(ctx, streamId, allocator)); + } + }); return result; } diff --git a/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFrameStreamProducer.java b/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFrameStreamProducer.java index 3d0fc6fe908d2..43429fec07420 100644 --- a/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFrameStreamProducer.java +++ b/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFrameStreamProducer.java @@ -79,7 +79,8 @@ public void onCancel() { void close() throws Exception { if (recordBatchStream != null) { recordBatchStream.close(); - }; + } + ; df.close(); } }; diff --git a/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFusion.java b/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFusion.java index ced5ade666080..fc39654315b44 100644 --- a/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFusion.java +++ b/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFusion.java @@ -9,7 +9,6 @@ package org.opensearch.datafusion; import java.util.List; -import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.function.BiConsumer; diff --git a/libs/datafusion/src/main/java/org.opensearch.datafusion/RecordBatchStream.java b/libs/datafusion/src/main/java/org.opensearch.datafusion/RecordBatchStream.java index ba2192464cd65..75dd86f16b993 100644 --- a/libs/datafusion/src/main/java/org.opensearch.datafusion/RecordBatchStream.java +++ b/libs/datafusion/src/main/java/org.opensearch.datafusion/RecordBatchStream.java @@ -34,7 +34,6 @@ public RecordBatchStream(SessionContext ctx, long streamId, BufferAllocator allo private static native void destroy(long pointer); - @Override public void close() throws Exception { destroy(ptr); @@ -48,28 +47,24 @@ public void close() throws Exception { public CompletableFuture loadNextBatch() { ensureInitialized(); - long runtimePointer = context.getRuntime(); + long runtimePointer = context.getRuntime(); CompletableFuture result = new CompletableFuture<>(); - next( - runtimePointer, - ptr, - (String errString, long arrowArrayAddress) -> { - if (errString != null && errString.isEmpty() == false) { - result.completeExceptionally(new RuntimeException(errString)); - } else if (arrowArrayAddress == 0) { - // Reached end of stream - result.complete(false); - } else { - try { - ArrowArray arrowArray = ArrowArray.wrap(arrowArrayAddress); - Data.importIntoVectorSchemaRoot( - allocator, arrowArray, vectorSchemaRoot, dictionaryProvider); - result.complete(true); - } catch (Exception e) { - result.completeExceptionally(e); - } + next(runtimePointer, ptr, (String errString, long arrowArrayAddress) -> { + if (errString != null && errString.isEmpty() == false) { + result.completeExceptionally(new RuntimeException(errString)); + } else if (arrowArrayAddress == 0) { + // Reached end of stream + result.complete(false); + } else { + try { + ArrowArray arrowArray = ArrowArray.wrap(arrowArrayAddress); + Data.importIntoVectorSchemaRoot(allocator, arrowArray, vectorSchemaRoot, dictionaryProvider); + result.complete(true); + } catch (Exception e) { + result.completeExceptionally(e); } - }); + } + }); return result; } @@ -92,22 +87,20 @@ private void ensureInitialized() { private Schema getSchema() { // Native method is not async, but use a future to store the result for convenience CompletableFuture result = new CompletableFuture<>(); - getSchema( - ptr, - (errString, arrowSchemaAddress) -> { - if (errString != null && errString.isEmpty() == false) { - result.completeExceptionally(new RuntimeException(errString)); - } else { - try { - ArrowSchema arrowSchema = ArrowSchema.wrap(arrowSchemaAddress); - Schema schema = Data.importSchema(allocator, arrowSchema, dictionaryProvider); - result.complete(schema); - // The FFI schema will be released from rust when it is dropped - } catch (Exception e) { - result.completeExceptionally(e); - } + getSchema(ptr, (errString, arrowSchemaAddress) -> { + if (errString != null && errString.isEmpty() == false) { + result.completeExceptionally(new RuntimeException(errString)); + } else { + try { + ArrowSchema arrowSchema = ArrowSchema.wrap(arrowSchemaAddress); + Schema schema = Data.importSchema(allocator, arrowSchema, dictionaryProvider); + result.complete(schema); + // The FFI schema will be released from rust when it is dropped + } catch (Exception e) { + result.completeExceptionally(e); } - }); + } + }); return result.join(); } diff --git a/modules/arrow-flight/src/test/java/org/opensearch/flight/BaseFlightProducerTests.java b/modules/arrow-flight/src/test/java/org/opensearch/flight/BaseFlightProducerTests.java index 1d4262f7b8f27..9442f4b5a1975 100644 --- a/modules/arrow-flight/src/test/java/org/opensearch/flight/BaseFlightProducerTests.java +++ b/modules/arrow-flight/src/test/java/org/opensearch/flight/BaseFlightProducerTests.java @@ -17,8 +17,8 @@ import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.ipc.message.IpcOption; -import org.opensearch.arrow.StreamProducer; import org.opensearch.arrow.StreamManager; +import org.opensearch.arrow.StreamProducer; import org.opensearch.arrow.StreamTicket; import org.opensearch.test.OpenSearchTestCase; @@ -92,8 +92,7 @@ public void start(VectorSchemaRoot root) { } @Override - public void start(VectorSchemaRoot root, DictionaryProvider dictionaries, IpcOption option) { - } + public void start(VectorSchemaRoot root, DictionaryProvider dictionaries, IpcOption option) {} @Override public void putNext(ArrowBuf metadata) { diff --git a/modules/arrow-flight/src/test/java/org/opensearch/flight/FlightServiceTests.java b/modules/arrow-flight/src/test/java/org/opensearch/flight/FlightServiceTests.java index 54c8b26a35c93..b6dc6d5be10b7 100644 --- a/modules/arrow-flight/src/test/java/org/opensearch/flight/FlightServiceTests.java +++ b/modules/arrow-flight/src/test/java/org/opensearch/flight/FlightServiceTests.java @@ -8,8 +8,8 @@ package org.opensearch.flight; -import org.opensearch.test.OpenSearchTestCase; import org.opensearch.common.settings.Settings; +import org.opensearch.test.OpenSearchTestCase; public class FlightServiceTests extends OpenSearchTestCase { diff --git a/modules/arrow-flight/src/test/java/org/opensearch/flight/FlightStreamIteratorTests.java b/modules/arrow-flight/src/test/java/org/opensearch/flight/FlightStreamIteratorTests.java index c4898c1bee716..a7627b69670ab 100644 --- a/modules/arrow-flight/src/test/java/org/opensearch/flight/FlightStreamIteratorTests.java +++ b/modules/arrow-flight/src/test/java/org/opensearch/flight/FlightStreamIteratorTests.java @@ -25,7 +25,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -public class FlightStreamIteratorTests extends OpenSearchTestCase { +public class FlightStreamIteratorTests extends OpenSearchTestCase { private FlightStream mockFlightStream; diff --git a/modules/arrow-flight/src/test/java/org/opensearch/flight/FlightStreamManagerTests.java b/modules/arrow-flight/src/test/java/org/opensearch/flight/FlightStreamManagerTests.java index d2c98cf49dd2d..70ff44539048e 100644 --- a/modules/arrow-flight/src/test/java/org/opensearch/flight/FlightStreamManagerTests.java +++ b/modules/arrow-flight/src/test/java/org/opensearch/flight/FlightStreamManagerTests.java @@ -20,7 +20,9 @@ import java.util.Collections; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; public class FlightStreamManagerTests extends OpenSearchTestCase { @@ -36,7 +38,7 @@ public void setUp() throws Exception { FlightService flightService = mock(FlightService.class); when(flightService.getFlightClient(NODE_ID)).thenReturn(flightClient); BufferAllocator allocator = mock(BufferAllocator.class); - flightStreamManager = new FlightStreamManager(()->allocator, flightService); + flightStreamManager = new FlightStreamManager(() -> allocator, flightService); } public void testGetStreamIterator() { diff --git a/modules/arrow-flight/src/test/java/org/opensearch/flight/FlightStreamPluginTests.java b/modules/arrow-flight/src/test/java/org/opensearch/flight/FlightStreamPluginTests.java index 005a9da0f4b55..cff2c1380b72f 100644 --- a/modules/arrow-flight/src/test/java/org/opensearch/flight/FlightStreamPluginTests.java +++ b/modules/arrow-flight/src/test/java/org/opensearch/flight/FlightStreamPluginTests.java @@ -9,8 +9,8 @@ package org.opensearch.flight; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.test.OpenSearchTestCase; import org.opensearch.common.settings.Settings; +import org.opensearch.test.OpenSearchTestCase; import java.util.Collection; @@ -29,19 +29,41 @@ public void setUp() throws Exception { } public void testCreateComponents() { - Collection components = flightStreamPlugin.createComponents(null, mock(ClusterService.class), null,null, null,null, null, null, null, null, null); + Collection components = flightStreamPlugin.createComponents( + null, + mock(ClusterService.class), + null, + null, + null, + null, + null, + null, + null, + null, + null + ); assertNotNull(components); assertTrue(components.stream().anyMatch(component -> component instanceof FlightService)); } - public void testGetStreamManager() { - } + public void testGetStreamManager() {} - public void testGetSettings() { - } + public void testGetSettings() {} public void testCreateComponentsWithNullArguments() { - Collection components = flightStreamPlugin.createComponents(null, mock(ClusterService.class), null,null, null,null, null, null, null, null, null); + Collection components = flightStreamPlugin.createComponents( + null, + mock(ClusterService.class), + null, + null, + null, + null, + null, + null, + null, + null, + null + ); assertNotNull(components); assertFalse(components.isEmpty()); } diff --git a/modules/arrow-flight/src/test/java/org/opensearch/flight/ProxyStreamProducerTests.java b/modules/arrow-flight/src/test/java/org/opensearch/flight/ProxyStreamProducerTests.java index 18cabbae10002..a3cea25c6265f 100644 --- a/modules/arrow-flight/src/test/java/org/opensearch/flight/ProxyStreamProducerTests.java +++ b/modules/arrow-flight/src/test/java/org/opensearch/flight/ProxyStreamProducerTests.java @@ -14,7 +14,11 @@ import org.opensearch.arrow.StreamProducer; import org.opensearch.test.OpenSearchTestCase; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; public class ProxyStreamProducerTests extends OpenSearchTestCase { diff --git a/server/src/main/java/org/opensearch/action/ActionModule.java b/server/src/main/java/org/opensearch/action/ActionModule.java index b0143edf90527..f6380ae81ae62 100644 --- a/server/src/main/java/org/opensearch/action/ActionModule.java +++ b/server/src/main/java/org/opensearch/action/ActionModule.java @@ -279,13 +279,13 @@ import org.opensearch.action.search.SearchScrollAction; import org.opensearch.action.search.StreamedJoinAction; import org.opensearch.action.search.TransportClearScrollAction; -import org.opensearch.action.search.TransportStreamedJoinAction; import org.opensearch.action.search.TransportCreatePitAction; import org.opensearch.action.search.TransportDeletePitAction; import org.opensearch.action.search.TransportGetAllPitsAction; import org.opensearch.action.search.TransportMultiSearchAction; import org.opensearch.action.search.TransportSearchAction; import org.opensearch.action.search.TransportSearchScrollAction; +import org.opensearch.action.search.TransportStreamedJoinAction; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.AutoCreateIndex; import org.opensearch.action.support.DestructiveOperations; diff --git a/server/src/main/java/org/opensearch/action/search/FetchSearchPhase.java b/server/src/main/java/org/opensearch/action/search/FetchSearchPhase.java index 2ad7f8a29896c..ef99c55ba1eca 100644 --- a/server/src/main/java/org/opensearch/action/search/FetchSearchPhase.java +++ b/server/src/main/java/org/opensearch/action/search/FetchSearchPhase.java @@ -142,10 +142,10 @@ private void innerRun() throws Exception { queryAndFetchOptimization ? queryResults : fetchResults.getAtomicArray() ); if (queryAndFetchOptimization) { - assert phaseResults.isEmpty() || phaseResults.get(0).fetchResult() != null : "phaseResults empty [" - + phaseResults.isEmpty() - + "], single result: " - + phaseResults.get(0).fetchResult(); + // assert phaseResults.isEmpty() || phaseResults.get(0).fetchResult() != null : "phaseResults empty [" + // + phaseResults.isEmpty() + // + "], single result: " + // + phaseResults.get(0).fetchResult(); // query AND fetch optimization finishPhase.run(); } else { diff --git a/server/src/main/java/org/opensearch/action/search/JoinRequest.java b/server/src/main/java/org/opensearch/action/search/JoinRequest.java index 5bbc15d6b820f..c432e9db7db52 100644 --- a/server/src/main/java/org/opensearch/action/search/JoinRequest.java +++ b/server/src/main/java/org/opensearch/action/search/JoinRequest.java @@ -12,17 +12,10 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.annotation.ExperimentalApi; -import org.opensearch.common.annotation.PublicApi; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.xcontent.ToXContentObject; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; import static org.opensearch.action.ValidateActions.addValidationError; diff --git a/server/src/main/java/org/opensearch/action/search/JoinResponse.java b/server/src/main/java/org/opensearch/action/search/JoinResponse.java index ca642d13c0ba4..09bc0a5178722 100644 --- a/server/src/main/java/org/opensearch/action/search/JoinResponse.java +++ b/server/src/main/java/org/opensearch/action/search/JoinResponse.java @@ -10,7 +10,6 @@ import org.apache.lucene.search.TotalHits; import org.opensearch.common.annotation.ExperimentalApi; -import org.opensearch.common.annotation.PublicApi; import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -42,7 +41,7 @@ public OSTicket getTicket() { public JoinResponse(OSTicket ticket) { this.ticket = ticket; - this.hits = new SearchHits(new SearchHit[]{}, new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0f); + this.hits = new SearchHits(new SearchHit[] {}, new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0f); } public JoinResponse(SearchHits hits) { @@ -58,14 +57,12 @@ public JoinResponse(StreamInput in) throws IOException { @Override public void writeTo(StreamOutput out) throws IOException { - ticket.writeTo(out); - hits.writeTo(out); + ticket.writeTo(out); + hits.writeTo(out); } @Override public String toString() { - return "JoinResponse{" + - "ticket=" + ticket + - '}'; + return "JoinResponse{" + "ticket=" + ticket + '}'; } } diff --git a/server/src/main/java/org/opensearch/action/search/QueryPhaseResultConsumer.java b/server/src/main/java/org/opensearch/action/search/QueryPhaseResultConsumer.java index 7949db3762abd..39d527f345da4 100644 --- a/server/src/main/java/org/opensearch/action/search/QueryPhaseResultConsumer.java +++ b/server/src/main/java/org/opensearch/action/search/QueryPhaseResultConsumer.java @@ -128,9 +128,13 @@ public void close() { @Override public void consumeResult(SearchPhaseResult result, Runnable next) { super.consumeResult(result, () -> {}); - QuerySearchResult querySearchResult = result.queryResult(); - progressListener.notifyQueryResult(querySearchResult.getShardIndex()); - pendingMerges.consume(querySearchResult, next); + if (result instanceof StreamSearchResult) { + next.run(); + } else { + QuerySearchResult querySearchResult = result.queryResult(); + progressListener.notifyQueryResult(querySearchResult.getShardIndex()); + pendingMerges.consume(querySearchResult, next); + } } @Override @@ -146,8 +150,12 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception { SearchPhaseController.ReducedQueryPhase reducePhase = null; if (results.get(0) instanceof StreamSearchResult) { - reducePhase = controller.reducedFromStream(results.asList() - .stream().map(r -> (StreamSearchResult) r).collect(Collectors.toList())); + reducePhase = controller.reducedFromStream( + results.asList().stream().map(r -> (StreamSearchResult) r).collect(Collectors.toList()), + aggReduceContextBuilder, + performFinalReduce + ); + logger.info("Will reduce results for {}", results.get(0)); } else { final SearchPhaseController.TopDocsStats topDocsStats = pendingMerges.consumeTopDocsStats(); final List topDocsList = pendingMerges.consumeTopDocs(); @@ -171,7 +179,11 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception { // Update the circuit breaker to replace the estimation with the serialized size of the newly reduced result long finalSize = reducePhase.aggregations.getSerializedSize() - breakerSize; pendingMerges.addWithoutBreaking(finalSize); - logger.trace("aggs final reduction [{}] max [{}]", pendingMerges.aggsCurrentBufferSize, pendingMerges.maxAggsCurrentBufferSize); + logger.trace( + "aggs final reduction [{}] max [{}]", + pendingMerges.aggsCurrentBufferSize, + pendingMerges.maxAggsCurrentBufferSize + ); } } progressListener.notifyFinalReduce( diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java b/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java index 18e97023fac9d..68be31f890491 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java +++ b/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java @@ -32,16 +32,15 @@ package org.opensearch.action.search; -import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.BitVector; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.Float4Vector; import org.apache.arrow.vector.Float8Vector; import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.UInt8Vector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.lucene.index.Term; import org.apache.lucene.search.CollectionStatistics; import org.apache.lucene.search.FieldDoc; @@ -55,22 +54,25 @@ import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.TotalHits.Relation; import org.apache.lucene.search.grouping.CollapseTopFieldDocs; +import org.apache.lucene.util.BytesRef; +import org.opensearch.arrow.StreamIterator; import org.opensearch.arrow.StreamManager; +import org.opensearch.arrow.StreamTicket; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.core.common.breaker.CircuitBreaker; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; -import org.opensearch.core.index.shard.ShardId; -import org.opensearch.datafusion.DataFrame; -import org.opensearch.datafusion.DataFusion; -import org.opensearch.datafusion.SessionContext; import org.opensearch.search.DocValueFormat; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.SearchService; +import org.opensearch.search.aggregations.BucketOrder; import org.opensearch.search.aggregations.InternalAggregation; import org.opensearch.search.aggregations.InternalAggregation.ReduceContext; import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.aggregations.InternalOrder; +import org.opensearch.search.aggregations.bucket.terms.StringTerms; +import org.opensearch.search.aggregations.bucket.terms.TermsAggregator; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.dfs.AggregatedDfs; import org.opensearch.search.dfs.DfsSearchResult; @@ -87,15 +89,12 @@ import org.opensearch.search.suggest.Suggest.Suggestion; import org.opensearch.search.suggest.completion.CompletionSuggestion; -import java.io.IOException; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.function.Consumer; import java.util.function.Function; @@ -117,7 +116,8 @@ public final class SearchPhaseController { public SearchPhaseController( NamedWriteableRegistry namedWriteableRegistry, Function requestToAggReduceContextBuilder, - StreamManager streamManager) { + StreamManager streamManager + ) { this.namedWriteableRegistry = namedWriteableRegistry; this.requestToAggReduceContextBuilder = requestToAggReduceContextBuilder; this.streamManager = streamManager; @@ -125,7 +125,8 @@ public SearchPhaseController( public SearchPhaseController( NamedWriteableRegistry namedWriteableRegistry, - Function requestToAggReduceContextBuilder) { + Function requestToAggReduceContextBuilder + ) { this(namedWriteableRegistry, requestToAggReduceContextBuilder, null); } @@ -595,9 +596,9 @@ private static InternalAggregations reduceAggs( return toReduce.isEmpty() ? null : InternalAggregations.topLevelReduce( - toReduce, - performFinalReduce ? aggReduceContextBuilder.forFinalReduction() : aggReduceContextBuilder.forPartialReduction() - ); + toReduce, + performFinalReduce ? aggReduceContextBuilder.forFinalReduction() : aggReduceContextBuilder.forPartialReduction() + ); } /** @@ -689,153 +690,216 @@ static int getTopDocsSize(SearchRequest request) { : source.from()); } - public ReducedQueryPhase reducedAggsFromStream(List list) { - - - try (SessionContext context = new SessionContext()) { + public ReducedQueryPhase reducedFromStream( + List list, + InternalAggregation.ReduceContextBuilder aggReduceContextBuilder, + boolean performFinalReduce + ) { + System.out.println("will try to reduce from stream here"); + try { - List tickets = list.stream().flatMap(r -> r.getFlightTickets().stream()) + List tickets = list.stream() + .flatMap(r -> r.getFlightTickets().stream()) .map(OSTicket::getBytes) .collect(Collectors.toList()); - - // execute the query and get a dataframe - CompletableFuture frame = DataFusion.query(tickets); - - DataFrame dataFrame = null; - ArrowReader arrowReader = null; - try { - dataFrame = frame.get(); - arrowReader = dataFrame.collect(new RootAllocator()).get(); - } catch (InterruptedException | ExecutionException e) { - throw new RuntimeException(e); - } - int totalRows = 0; List scoreDocs = new ArrayList<>(); - try { - while (arrowReader.loadNextBatch()) { - VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); + TotalHits totalHits = new TotalHits(totalRows, Relation.EQUAL_TO); + List aggs = new ArrayList<>(); + Map bucketMap = new HashMap(); + + for (byte[] ticket : tickets) { + StreamIterator streamIterator = streamManager.getStreamIterator(StreamTicket.fromBytes(ticket)); + while (streamIterator.next()) { + VectorSchemaRoot root = streamIterator.getRoot(); int rowCount = root.getRowCount(); - totalRows+= rowCount; + totalRows += rowCount; System.out.println("Record Batch with " + rowCount + " rows:"); // Iterate through rows for (int row = 0; row < rowCount; row++) { - FieldVector docID = root.getVector("docID"); - Float4Vector score = (Float4Vector) root.getVector("score"); - FieldVector shardID = root.getVector("shardID"); - FieldVector nodeID = root.getVector("nodeID"); - - ShardId sid = ShardId.fromString((String) getValue(shardID, row)); - int value = (int) getValue(docID, row); - System.out.println("DocID: " + value + " ShardID" + sid + "NodeID: " + getValue(nodeID, row)); - scoreDocs.add(new ScoreDoc(value, score.get(row), sid.id())); + FieldVector ord = root.getVector("ord"); + FieldVector count = root.getVector("count"); + long countValue = (long) getValue(count, row); + String ordValue = (String) getValue(ord, row); + bucketMap.put(ordValue, bucketMap.getOrDefault(ordValue, 0L) + countValue); } } - TotalHits totalHits = new TotalHits(totalRows, Relation.EQUAL_TO); - return new ReducedQueryPhase( - totalHits, - totalRows, - 1.0f, - false, - false, - null, - null, - null, - new SortedTopDocs(scoreDocs.toArray(ScoreDoc[]::new), false, null, null, null), + // List orders = new ArrayList<>(); + // orders.add(BucketOrder.count(false)); + // aggs.add(new StringTerms( + // "categories", + // InternalOrder.key(true), + // InternalOrder.compound(orders), + // null, + // DocValueFormat.RAW, + // 25, + // false, + // 0L, + // buckets, + // 0, + // new TermsAggregator.BucketCountThresholds(1, 0, 10, 25) + // )); + // InternalAggregations agg = InternalAggregations.reduce(List.of(InternalAggregations.from(aggs)), + // aggReduceContextBuilder.forFinalReduction()); + // finalAggs.add(agg); + } + List orders = new ArrayList<>(); + orders.add(BucketOrder.count(false)); + List buckets = new ArrayList<>(); + bucketMap.entrySet().stream().sorted(Map.Entry.comparingByValue().reversed()).limit(500).forEach(entry -> { + buckets.add( + new StringTerms.Bucket( + new BytesRef(entry.getKey()), + entry.getValue(), + new InternalAggregations(List.of()), + false, + 0, + DocValueFormat.RAW + ) + ); + }); + aggs.add( + new StringTerms( + "categories", + InternalOrder.key(true), + InternalOrder.compound(orders), null, - 1, - totalRows, + DocValueFormat.RAW, + 25, + false, + 0L, + buckets, 0, - totalRows == 0, - list.stream().flatMap(ssr -> ssr.getFlightTickets().stream()).collect(Collectors.toList()) - ); - } catch (IOException e) { - throw new RuntimeException(e); - } finally { - try { - arrowReader.close(); - } catch (IOException e) { - throw new RuntimeException(e); - } - } + new TermsAggregator.BucketCountThresholds(1, 0, 10, 25) + ) + ); + + // InternalAggregations finalReduce = reduceAggs(aggReduceContextBuilder, performFinalReduce, + // List.of(InternalAggregations.from(aggs))); + + return new ReducedQueryPhase( + totalHits, + totalRows, + 1.0f, + false, + false, + null, + InternalAggregations.from(aggs), + null, + new SortedTopDocs(scoreDocs.toArray(ScoreDoc[]::new), false, null, null, null), + null, + 1, + 500, + 0, + totalRows == 0, + list.stream().flatMap(ssr -> ssr.getFlightTickets().stream()).collect(Collectors.toList()) + ); } catch (Exception e) { throw new RuntimeException(e); } } - public ReducedQueryPhase reducedFromStream(List list) { - - - try (SessionContext context = new SessionContext()) { + public ReducedQueryPhase reducedFromStreamInternal( + List list, + InternalAggregation.ReduceContextBuilder aggReduceContextBuilder, + boolean performFinalReduce + ) { + System.out.println("will try to reduce from stream here"); + try { - List tickets = list.stream().flatMap(r -> r.getFlightTickets().stream()) + List tickets = list.stream() + .flatMap(r -> r.getFlightTickets().stream()) .map(OSTicket::getBytes) .collect(Collectors.toList()); - - // execute the query and get a dataframe - CompletableFuture frame = DataFusion.query(tickets); - - DataFrame dataFrame = null; - ArrowReader arrowReader = null; - try { - dataFrame = frame.get(); - arrowReader = dataFrame.collect(new RootAllocator()).get(); - } catch (InterruptedException | ExecutionException e) { - throw new RuntimeException(e); - } - int totalRows = 0; List scoreDocs = new ArrayList<>(); - try { - while (arrowReader.loadNextBatch()) { - VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); + TotalHits totalHits = new TotalHits(totalRows, Relation.EQUAL_TO); + List aggs = new ArrayList<>(); + Map bucketMap = new HashMap(); + + for (byte[] ticket : tickets) { + StreamIterator streamIterator = streamManager.getStreamIterator(StreamTicket.fromBytes(ticket)); + while (streamIterator.next()) { + VectorSchemaRoot root = streamIterator.getRoot(); int rowCount = root.getRowCount(); - totalRows+= rowCount; + totalRows += rowCount; System.out.println("Record Batch with " + rowCount + " rows:"); // Iterate through rows for (int row = 0; row < rowCount; row++) { - FieldVector docID = root.getVector("docID"); - Float4Vector score = (Float4Vector) root.getVector("score"); - FieldVector shardID = root.getVector("shardID"); - FieldVector nodeID = root.getVector("nodeID"); - - ShardId sid = ShardId.fromString((String) getValue(shardID, row)); - int value = (int) getValue(docID, row); - System.out.println("DocID: " + value + " ShardID" + sid + "NodeID: " + getValue(nodeID, row)); - scoreDocs.add(new ScoreDoc(value, score.get(row), sid.id())); + FieldVector ord = root.getVector("ord"); + FieldVector count = root.getVector("count"); + long countValue = (long) getValue(count, row); + String ordValue = (String) getValue(ord, row); + bucketMap.put(ordValue, bucketMap.getOrDefault(ordValue, 0L) + countValue); } } - TotalHits totalHits = new TotalHits(totalRows, Relation.EQUAL_TO); - return new ReducedQueryPhase( - totalHits, - totalRows, - 1.0f, - false, - false, - null, - null, - null, - new SortedTopDocs(scoreDocs.toArray(ScoreDoc[]::new), false, null, null, null), + // List orders = new ArrayList<>(); + // orders.add(BucketOrder.count(false)); + // aggs.add(new StringTerms( + // "categories", + // InternalOrder.key(true), + // InternalOrder.compound(orders), + // null, + // DocValueFormat.RAW, + // 25, + // false, + // 0L, + // buckets, + // 0, + // new TermsAggregator.BucketCountThresholds(1, 0, 10, 25) + // )); + // InternalAggregations agg = InternalAggregations.reduce(List.of(InternalAggregations.from(aggs)), + // aggReduceContextBuilder.forFinalReduction()); + // finalAggs.add(agg); + } + List orders = new ArrayList<>(); + orders.add(BucketOrder.count(false)); + List buckets = new ArrayList<>(); + bucketMap.forEach((key, value) -> { + buckets.add( + new StringTerms.Bucket(new BytesRef(key), value, new InternalAggregations(List.of()), false, 0, DocValueFormat.RAW) + ); + }); + aggs.add( + new StringTerms( + "categories", + InternalOrder.key(true), + InternalOrder.compound(orders), null, - 1, - totalRows, + DocValueFormat.RAW, + 25, + false, + 0L, + buckets, 0, - totalRows == 0, - list.stream().flatMap(ssr -> ssr.getFlightTickets().stream()).collect(Collectors.toList()) - ); - } catch (IOException e) { - throw new RuntimeException(e); - } finally { - try { - arrowReader.close(); - } catch (IOException e) { - throw new RuntimeException(e); - } - } + new TermsAggregator.BucketCountThresholds(1, 0, 10, 25) + ) + ); + + // InternalAggregations finalReduce = reduceAggs(aggReduceContextBuilder, performFinalReduce, + // List.of(InternalAggregations.from(aggs))); + + return new ReducedQueryPhase( + totalHits, + totalRows, + 1.0f, + false, + false, + null, + InternalAggregations.from(aggs), + null, + new SortedTopDocs(scoreDocs.toArray(ScoreDoc[]::new), false, null, null, null), + null, + 1, + totalRows, + 0, + totalRows == 0, + list.stream().flatMap(ssr -> ssr.getFlightTickets().stream()).collect(Collectors.toList()) + ); } catch (Exception e) { throw new RuntimeException(e); } @@ -858,6 +922,8 @@ private static Object getValue(FieldVector vector, int index) { return new String(((VarCharVector) vector).get(index)); } else if (vector instanceof BitVector) { return ((BitVector) vector).get(index) != 0; + } else if (vector instanceof UInt8Vector) { + return ((UInt8Vector) vector).get(index); } // Add more types as needed diff --git a/server/src/main/java/org/opensearch/action/search/SearchResponse.java b/server/src/main/java/org/opensearch/action/search/SearchResponse.java index e346a6fc84baa..35f0b17dc4d9d 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchResponse.java +++ b/server/src/main/java/org/opensearch/action/search/SearchResponse.java @@ -56,11 +56,9 @@ import org.opensearch.search.SearchExtBuilder; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; -import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.aggregations.Aggregations; import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.search.internal.InternalSearchResponse; -import org.opensearch.search.internal.ShardStreamQueryResult; import org.opensearch.search.profile.ProfileShardResult; import org.opensearch.search.profile.SearchProfileShardResults; import org.opensearch.search.stream.OSTicket; diff --git a/server/src/main/java/org/opensearch/action/search/SearchResponseSections.java b/server/src/main/java/org/opensearch/action/search/SearchResponseSections.java index b02493893f558..2d713c4c3dc67 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchResponseSections.java +++ b/server/src/main/java/org/opensearch/action/search/SearchResponseSections.java @@ -39,10 +39,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.search.SearchExtBuilder; import org.opensearch.search.SearchHits; -import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.aggregations.Aggregations; -import org.opensearch.search.aggregations.InternalAggregations; -import org.opensearch.search.internal.ShardStreamQueryResult; import org.opensearch.search.profile.ProfileShardResult; import org.opensearch.search.profile.SearchProfileShardResults; import org.opensearch.search.stream.OSTicket; diff --git a/server/src/main/java/org/opensearch/action/search/SearchTransportService.java b/server/src/main/java/org/opensearch/action/search/SearchTransportService.java index f26f8c2a75555..8c30d37076cc9 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchTransportService.java +++ b/server/src/main/java/org/opensearch/action/search/SearchTransportService.java @@ -60,7 +60,6 @@ import org.opensearch.search.query.QuerySearchRequest; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.search.query.ScrollQuerySearchResult; -import org.opensearch.search.query.StreamQueryResponse; import org.opensearch.search.stream.StreamSearchResult; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.RemoteClusterService; @@ -244,7 +243,6 @@ public void sendExecuteQuery( // we optimize this and expect a QueryFetchSearchResult if we only have a single shard in the search request // this used to be the QUERY_AND_FETCH which doesn't exist anymore. - if (request.isStreamRequest()) { Writeable.Reader reader = StreamSearchResult::new; final ActionListener handler = responseWrapper.apply(connection, listener); diff --git a/server/src/main/java/org/opensearch/action/search/SearchType.java b/server/src/main/java/org/opensearch/action/search/SearchType.java index a8ada789adf22..a8e75c5f89113 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchType.java +++ b/server/src/main/java/org/opensearch/action/search/SearchType.java @@ -89,7 +89,7 @@ public static SearchType fromId(byte id) { } else if (id == 1 || id == 3) { // TODO this bwc layer can be removed once this is back-ported to 5.3 QUERY_AND_FETCH is removed // now return QUERY_THEN_FETCH; - } else if (id == 5) { + } else if (id == 5) { return STREAM; } else { throw new IllegalArgumentException("No search type for [" + id + "]"); diff --git a/server/src/main/java/org/opensearch/action/search/StreamAsyncAction.java b/server/src/main/java/org/opensearch/action/search/StreamAsyncAction.java index 00da8c2706a92..e9670ac1ed348 100644 --- a/server/src/main/java/org/opensearch/action/search/StreamAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/StreamAsyncAction.java @@ -32,7 +32,6 @@ package org.opensearch.action.search; -import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.VectorSchemaRoot; @@ -40,26 +39,15 @@ import org.apache.logging.log4j.Logger; import org.opensearch.arrow.StreamIterator; import org.opensearch.arrow.StreamManager; -import org.opensearch.arrow.StreamProducer; import org.opensearch.arrow.StreamTicket; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.routing.GroupShardsIterator; import org.opensearch.common.util.concurrent.AbstractRunnable; -import org.opensearch.common.util.concurrent.AtomicArray; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.shard.ShardId; -import org.opensearch.datafusion.DataFrame; -import org.opensearch.datafusion.DataFrameStreamProducer; -import org.opensearch.datafusion.DataFusion; -import org.opensearch.datafusion.RecordBatchStream; -import org.opensearch.datafusion.SessionContext; import org.opensearch.search.SearchHits; import org.opensearch.search.SearchPhaseResult; -import org.opensearch.search.SearchShardTarget; import org.opensearch.search.internal.AliasFilter; import org.opensearch.search.internal.InternalSearchResponse; -import org.opensearch.search.internal.ShardSearchContextId; -import org.opensearch.search.internal.ShardStreamQueryResult; import org.opensearch.search.stream.OSTicket; import org.opensearch.search.stream.StreamSearchResult; import org.opensearch.telemetry.tracing.Tracer; @@ -84,15 +72,53 @@ class StreamAsyncAction extends SearchQueryThenFetchAsyncAction { public static Logger logger = LogManager.getLogger(StreamAsyncAction.class); private final SearchPhaseController searchPhaseController; - public StreamAsyncAction(Logger logger, SearchTransportService searchTransportService, BiFunction nodeIdToConnection, Map aliasFilter, Map concreteIndexBoosts, Map> indexRoutings, SearchPhaseController searchPhaseController, Executor executor, QueryPhaseResultConsumer resultConsumer, SearchRequest request, ActionListener listener, GroupShardsIterator shardsIts, TransportSearchAction.SearchTimeProvider timeProvider, ClusterState clusterState, SearchTask task, SearchResponse.Clusters clusters, SearchRequestContext searchRequestContext, Tracer tracer) { - super(logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts, indexRoutings, searchPhaseController, executor, resultConsumer, request, listener, shardsIts, timeProvider, clusterState, task, clusters, searchRequestContext, tracer); + public StreamAsyncAction( + Logger logger, + SearchTransportService searchTransportService, + BiFunction nodeIdToConnection, + Map aliasFilter, + Map concreteIndexBoosts, + Map> indexRoutings, + SearchPhaseController searchPhaseController, + Executor executor, + QueryPhaseResultConsumer resultConsumer, + SearchRequest request, + ActionListener listener, + GroupShardsIterator shardsIts, + TransportSearchAction.SearchTimeProvider timeProvider, + ClusterState clusterState, + SearchTask task, + SearchResponse.Clusters clusters, + SearchRequestContext searchRequestContext, + Tracer tracer + ) { + super( + logger, + searchTransportService, + nodeIdToConnection, + aliasFilter, + concreteIndexBoosts, + indexRoutings, + searchPhaseController, + executor, + resultConsumer, + request, + listener, + shardsIts, + timeProvider, + clusterState, + task, + clusters, + searchRequestContext, + tracer + ); this.searchPhaseController = searchPhaseController; } -// @Override -// protected SearchPhase getNextPhase(final SearchPhaseResults results, SearchPhaseContext context) { -// return new StreamSearchReducePhase("stream_reduce", context); -// } + // @Override + // protected SearchPhase getNextPhase(final SearchPhaseResults results, SearchPhaseContext context) { + // return new StreamSearchReducePhase("stream_reduce", context); + // } class StreamSearchReducePhase extends SearchPhase { private SearchPhaseContext context; @@ -122,40 +148,54 @@ protected void doRun() throws Exception { // fetch all the tickets (one byte[] per shard) and hand that off to Datafusion.Query // this creates a single stream that we'll register with the streammanager on this coordinator. List results = StreamAsyncAction.this.results.getAtomicArray().asList(); -// List tickets = results.stream().flatMap(r -> ((StreamSearchResult) r).getFlightTickets().stream()) -// .map(OSTicket::getBytes) -// .collect(Collectors.toList()); + // List tickets = results.stream().flatMap(r -> ((StreamSearchResult) r).getFlightTickets().stream()) + // .map(OSTicket::getBytes) + // .collect(Collectors.toList()); - List tickets = results.stream().flatMap(r -> ((StreamSearchResult) r).getFlightTickets().stream()) + List tickets = results.stream() + .flatMap(r -> ((StreamSearchResult) r).getFlightTickets().stream()) .collect(Collectors.toList()); // This is additional metadata for the fetch phase that will be conducted on the coordinator - // StreamTargetResponse is a wrapper for an individual shard that contains the contextId and ShardTarget that served the original + // StreamTargetResponse is a wrapper for an individual shard that contains the contextId and ShardTarget that served the + // original // query phase so we can fetch from it. - List targets = StreamAsyncAction.this.results.getAtomicArray().asList() + List targets = StreamAsyncAction.this.results.getAtomicArray() + .asList() .stream() .map(r -> new StreamTargetResponse(r.queryResult(), r.getSearchShardTarget())) .collect(Collectors.toList()); -// StreamManager streamManager = searchPhaseController.getStreamManager(); -// StreamIterator streamIterator = streamManager.getStreamIterator(StreamTicket.fromBytes(tickets.get(0))); -// List hits = new ArrayList<>(); -// while (streamIterator.next()) { -// VectorSchemaRoot root = streamIterator.getRoot(); -// int rowCount = root.getRowCount(); -// // Iterate through rows -// for (int row = 0; row < rowCount; row++) { -// FieldVector ord = root.getVector("ord"); -// FieldVector count = root.getVector("count");; -// -// -// int ordVal = (int) getValue(ord, row); -// int countVal = (int) getValue(count, row); -// logger.info("ORD {} COUNT {}", ordVal, countVal); -// } -// } -// StreamTicket streamTicket = streamManager.registerStream(DataFrameStreamProducer.query(tickets)); - InternalSearchResponse internalSearchResponse = new InternalSearchResponse(SearchHits.empty(), null, null, null, false, false, 1, Collections.emptyList(), List.of(tickets.get(0)), targets); + StreamManager streamManager = searchPhaseController.getStreamManager(); + StreamIterator streamIterator = streamManager.getStreamIterator(StreamTicket.fromBytes(tickets.get(0).getBytes())); + List hits = new ArrayList<>(); + while (streamIterator.next()) { + VectorSchemaRoot root = streamIterator.getRoot(); + int rowCount = root.getRowCount(); + // Iterate through rows + for (int row = 0; row < rowCount; row++) { + FieldVector ord = root.getVector("ord"); + FieldVector count = root.getVector("count"); + ; + + int ordVal = (int) getValue(ord, row); + int countVal = (int) getValue(count, row); + logger.info("ORD {} COUNT {}", ordVal, countVal); + } + } + // StreamTicket streamTicket = streamManager.registerStream(DataFrameStreamProducer.query(tickets)); + InternalSearchResponse internalSearchResponse = new InternalSearchResponse( + SearchHits.empty(), + null, + null, + null, + false, + false, + 1, + Collections.emptyList(), + List.of(tickets.get(0)), + targets + ); context.sendSearchResponse(internalSearchResponse, StreamAsyncAction.this.results.getAtomicArray()); } catch (Exception e) { logger.error("broken", e); diff --git a/server/src/main/java/org/opensearch/action/search/StreamTargetResponse.java b/server/src/main/java/org/opensearch/action/search/StreamTargetResponse.java index 3948106e52fba..f64e2785b93b2 100644 --- a/server/src/main/java/org/opensearch/action/search/StreamTargetResponse.java +++ b/server/src/main/java/org/opensearch/action/search/StreamTargetResponse.java @@ -49,9 +49,6 @@ public void writeTo(StreamOutput out) throws IOException { @Override public String toString() { - return "StreamTargetResponse{" + - "querySearchResult=" + querySearchResult + - ", searchShardTarget=" + searchShardTarget + - '}'; + return "StreamTargetResponse{" + "querySearchResult=" + querySearchResult + ", searchShardTarget=" + searchShardTarget + '}'; } } diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java index 5df862dd78017..3213704d4d4ec 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java @@ -124,7 +124,8 @@ import java.util.stream.StreamSupport; import static org.opensearch.action.admin.cluster.node.tasks.get.GetTaskAction.TASKS_ORIGIN; -import static org.opensearch.action.search.SearchType.*; +import static org.opensearch.action.search.SearchType.DFS_QUERY_THEN_FETCH; +import static org.opensearch.action.search.SearchType.QUERY_THEN_FETCH; import static org.opensearch.search.sort.FieldSortBuilder.hasPrimaryFieldSort; /** diff --git a/server/src/main/java/org/opensearch/action/search/TransportStreamedJoinAction.java b/server/src/main/java/org/opensearch/action/search/TransportStreamedJoinAction.java index 5793ce978871b..ccfed6d94eaa7 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportStreamedJoinAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportStreamedJoinAction.java @@ -96,35 +96,36 @@ public TransportStreamedJoinAction( */ @Override protected void doExecute(Task task, JoinRequest request, ActionListener listener) { - GroupedActionListener groupedListener = new GroupedActionListener<>( - new ActionListener<>() { - @Override - public void onResponse(Collection collection) { - List responses = new ArrayList<>(collection); - StreamTicket streamTicket = streamManager.registerStream(DataFrameStreamProducer.join( - tickets(responses.get(0)), - tickets(responses.get(1)), - request.getJoinField() - )); - if (request.isGetHits()) { - getHits(task, request, responses, streamTicket, listener); - } else { - listener.onResponse(new JoinResponse(new OSTicket(streamTicket.getTicketID(), streamTicket.getNodeID()))); - } + GroupedActionListener groupedListener = new GroupedActionListener<>(new ActionListener<>() { + @Override + public void onResponse(Collection collection) { + List responses = new ArrayList<>(collection); + StreamTicket streamTicket = streamManager.registerStream( + DataFrameStreamProducer.join(tickets(responses.get(0)), tickets(responses.get(1)), request.getJoinField()) + ); + if (request.isGetHits()) { + getHits(task, request, responses, streamTicket, listener); + } else { + listener.onResponse(new JoinResponse(new OSTicket(streamTicket.getTicketID(), streamTicket.getNodeID()))); } + } - @Override - public void onFailure(Exception e) { - listener.onFailure(new RuntimeException(e)); - } - }, - 2 - ); + @Override + public void onFailure(Exception e) { + listener.onFailure(new RuntimeException(e)); + } + }, 2); searchIndex(task, request.getLeftIndex(), groupedListener); searchIndex(task, request.getRightIndex(), groupedListener); } - private void getHits(Task task, JoinRequest request, List responses, StreamTicket ticket, ActionListener listener) { + private void getHits( + Task task, + JoinRequest request, + List responses, + StreamTicket ticket, + ActionListener listener + ) { String leftIndex = request.getLeftIndex().indices()[0]; String rightIndex = request.getRightIndex().indices()[0]; @@ -168,29 +169,51 @@ private void getHits(Task task, JoinRequest request, List respon // for each hit split by index Map>> hitsByIndex = organizeHitsByIndexAndShard(hits); - GroupedActionListener groupedListener = new GroupedActionListener<>( - new ActionListener<>() { - @Override - public void onResponse(Collection maps) { - - Map> left = getHitsForIndex(leftIndex, maps); - Map> right = getHitsForIndex(rightIndex, maps); - SearchHit[] searchHits = hits.stream() - .map(hit -> mergeSource(left.get(hit.joinValue), right.get(hit.joinValue), request.getLeftAlias(), request.getRightAlias())) - .toArray(SearchHit[]::new); - listener.onResponse(new JoinResponse(new SearchHits(searchHits, new TotalHits(searchHits.length, TotalHits.Relation.EQUAL_TO), 1.0f))); - } + GroupedActionListener groupedListener = new GroupedActionListener<>(new ActionListener<>() { + @Override + public void onResponse(Collection maps) { + + Map> left = getHitsForIndex(leftIndex, maps); + Map> right = getHitsForIndex(rightIndex, maps); + SearchHit[] searchHits = hits.stream() + .map( + hit -> mergeSource( + left.get(hit.joinValue), + right.get(hit.joinValue), + request.getLeftAlias(), + request.getRightAlias() + ) + ) + .toArray(SearchHit[]::new); + listener.onResponse( + new JoinResponse(new SearchHits(searchHits, new TotalHits(searchHits.length, TotalHits.Relation.EQUAL_TO), 1.0f)) + ); + } - @Override - public void onFailure(Exception e) { - listener.onFailure(new RuntimeException(e)); - } - }, - 2 - ); + @Override + public void onFailure(Exception e) { + listener.onFailure(new RuntimeException(e)); + } + }, 2); assert hitsByIndex.size() == 2; - getIndexResults(hitsByIndex.get(leftIndex), targets, task, request.getLeftIndex(), request.getJoinField(), leftIndex, groupedListener); - getIndexResults(hitsByIndex.get(rightIndex), targets, task, request.getRightIndex(), request.getJoinField(), rightIndex, groupedListener); + getIndexResults( + hitsByIndex.get(leftIndex), + targets, + task, + request.getLeftIndex(), + request.getJoinField(), + leftIndex, + groupedListener + ); + getIndexResults( + hitsByIndex.get(rightIndex), + targets, + task, + request.getRightIndex(), + request.getJoinField(), + rightIndex, + groupedListener + ); } @@ -214,14 +237,13 @@ private SearchHit mergeSource(List l, List r, String leftA Map documentFields = new HashMap<>(); Map metaFields = new HashMap<>(); - left - .getFields() + left.getFields() .forEach( - (fieldName, docField) -> - (MapperService.META_FIELDS_BEFORE_7DOT8.contains(fieldName) - ? metaFields - : documentFields) - .put(fieldName, docField)); + (fieldName, docField) -> (MapperService.META_FIELDS_BEFORE_7DOT8.contains(fieldName) ? metaFields : documentFields).put( + fieldName, + docField + ) + ); String combinedId = left.getId() + "|" + right.getId(); SearchHit searchHit = new SearchHit(left.docId(), combinedId, documentFields, metaFields); searchHit.sourceRef(left.getSourceRef()); @@ -232,22 +254,28 @@ private SearchHit mergeSource(List l, List r, String leftA } private static Map prefixColNames(String prefix, SearchHit hit) { - return hit.getSourceAsMap().entrySet().stream() - .collect(Collectors.toMap( - entry -> prefix.concat(entry.getKey()), - Map.Entry::getValue, - (v1, v2) -> v2, // In case of duplicate keys, keep the last value - HashMap::new // Use HashMap as the map implementation - )); + return hit.getSourceAsMap() + .entrySet() + .stream() + .collect( + Collectors.toMap( + entry -> prefix.concat(entry.getKey()), + Map.Entry::getValue, + (v1, v2) -> v2, // In case of duplicate keys, keep the last value + HashMap::new // Use HashMap as the map implementation + ) + ); } - private void getIndexResults(Map> hitsPerShard, - Map targets, - Task task, - SearchRequest req, - String joinField, - String indexName, - ActionListener listener) { + private void getIndexResults( + Map> hitsPerShard, + Map targets, + Task task, + SearchRequest req, + String joinField, + String indexName, + ActionListener listener + ) { int count = (int) hitsPerShard.values().stream().mapToInt(List::size).count(); GroupedActionListener l = new GroupedActionListener<>(new ActionListener<>() { @Override @@ -269,22 +297,30 @@ public void onFailure(Exception e) { Transport.Connection connection = transportService.getConnection(node); ShardFetchSearchRequest fetchRequest = createFetchRequest( searchPhaseResult.getQuerySearchResult().getContextId(), - entry.getValue().stream().map(h -> h.leftShardId.getIndexName().equals(shardId.getIndexName()) ? h.leftDocId : h.rightDocId).collect(Collectors.toList()), + entry.getValue() + .stream() + .map(h -> h.leftShardId.getIndexName().equals(shardId.getIndexName()) ? h.leftDocId : h.rightDocId) + .collect(Collectors.toList()), searchShardTarget.getOriginalIndices(), searchPhaseResult.getQuerySearchResult().getShardSearchRequest(), searchPhaseResult.getQuerySearchResult().getRescoreDocIds() ); - searchTransportService.sendExecuteFetch(connection, fetchRequest, createSearchTask(task, req), new SearchActionListener<>(searchShardTarget, shardId.id()) { - @Override - protected void innerOnResponse(FetchSearchResult result) { - l.onResponse(result); - } + searchTransportService.sendExecuteFetch( + connection, + fetchRequest, + createSearchTask(task, req), + new SearchActionListener<>(searchShardTarget, shardId.id()) { + @Override + protected void innerOnResponse(FetchSearchResult result) { + l.onResponse(result); + } - @Override - public void onFailure(Exception e) { - l.onFailure(e); + @Override + public void onFailure(Exception e) { + l.onFailure(e); + } } - }); + ); } } @@ -298,9 +334,7 @@ private void reduceFetchResults( Map> hitsByJoinField = fetchResults.stream() .filter(result -> result.hits() != null) .flatMap(result -> Arrays.stream(result.hits().getHits())) - .collect(Collectors.groupingBy( - hit -> hit.getSourceAsMap().get(joinField) - )); + .collect(Collectors.groupingBy(hit -> hit.getSourceAsMap().get(joinField))); listener.onResponse(new HitsPerIndex(indexName, hitsByJoinField)); } catch (Exception e) { listener.onFailure(e); @@ -324,15 +358,7 @@ protected ShardFetchSearchRequest createFetchRequest( ShardSearchRequest shardSearchRequest, RescoreDocIds rescoreDocIds ) { - return new ShardFetchSearchRequest( - originalIndices, - contextId, - shardSearchRequest, - entry, - null, - rescoreDocIds, - null - ); + return new ShardFetchSearchRequest(originalIndices, contextId, shardSearchRequest, entry, null, rescoreDocIds, null); } static class Hit { @@ -356,23 +382,21 @@ public static Map>> organizeHitsByIndexAndShard(L // Group hits by left index and shards Map>> leftGroups = hits.stream() - .collect(Collectors.groupingBy( - hit -> hit.leftShardId.getIndex().getName(), + .collect( Collectors.groupingBy( - hit -> hit.leftShardId, - Collectors.toList() + hit -> hit.leftShardId.getIndex().getName(), + Collectors.groupingBy(hit -> hit.leftShardId, Collectors.toList()) ) - )); + ); // Group hits by right index and shards Map>> rightGroups = hits.stream() - .collect(Collectors.groupingBy( - hit -> hit.rightShardId.getIndex().getName(), + .collect( Collectors.groupingBy( - hit -> hit.rightShardId, - Collectors.toList() + hit -> hit.rightShardId.getIndex().getName(), + Collectors.groupingBy(hit -> hit.rightShardId, Collectors.toList()) ) - )); + ); // Combine both maps result.putAll(leftGroups); @@ -381,7 +405,6 @@ public static Map>> organizeHitsByIndexAndShard(L return result; } - private static Object getValue(FieldVector vector, int index) { if (vector == null || vector.isNull(index)) { return "null"; @@ -408,25 +431,14 @@ private static Object getValue(FieldVector vector, int index) { private void searchIndex(Task task, SearchRequest request, GroupedActionListener groupedListener) { SearchRequest leftRequest = request.searchType(SearchType.STREAM); SearchTask leftTask = createSearchTask(task, leftRequest); - transportSearchAction.doExecute( - leftTask, - leftRequest, - groupedListener - ); + transportSearchAction.doExecute(leftTask, leftRequest, groupedListener); } private static SearchTask createSearchTask(Task task, SearchRequest request) { - return request.createTask( - task.getId(), - task.getType(), - task.getAction(), - task.getParentTaskId(), - Collections.emptyMap() - ); + return request.createTask(task.getId(), task.getType(), task.getAction(), task.getParentTaskId(), Collections.emptyMap()); } List tickets(SearchResponse response) { - return Objects.requireNonNull(response.getTickets()).stream() - .map(OSTicket::getBytes).collect(Collectors.toList()); + return Objects.requireNonNull(response.getTickets()).stream().map(OSTicket::getBytes).collect(Collectors.toList()); } } diff --git a/server/src/main/java/org/opensearch/cluster/node/DiscoveryNode.java b/server/src/main/java/org/opensearch/cluster/node/DiscoveryNode.java index 082b32792ec9e..8c9a37a767ede 100644 --- a/server/src/main/java/org/opensearch/cluster/node/DiscoveryNode.java +++ b/server/src/main/java/org/opensearch/cluster/node/DiscoveryNode.java @@ -45,7 +45,6 @@ import org.opensearch.core.xcontent.ToXContentFragment; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.node.Node; -import org.opensearch.transport.TransportSettings; import java.io.IOException; import java.util.Collections; diff --git a/server/src/main/java/org/opensearch/common/util/FeatureFlags.java b/server/src/main/java/org/opensearch/common/util/FeatureFlags.java index 235537240cfc5..4f0462f0b5cdd 100644 --- a/server/src/main/java/org/opensearch/common/util/FeatureFlags.java +++ b/server/src/main/java/org/opensearch/common/util/FeatureFlags.java @@ -129,11 +129,7 @@ public class FeatureFlags { ); public static final String ARROW_STREAMS = "opensearch.experimental.feature.arrow.streams.enabled"; - public static final Setting ARROW_STREAMS_SETTING = Setting.boolSetting( - ARROW_STREAMS, - true, - Property.NodeScope - ); + public static final Setting ARROW_STREAMS_SETTING = Setting.boolSetting(ARROW_STREAMS, true, Property.NodeScope); private static final List> ALL_FEATURE_FLAG_SETTINGS = List.of( REMOTE_STORE_MIGRATION_EXPERIMENTAL_SETTING, diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index 1c318afc71692..cd4807612c9d2 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -314,7 +314,9 @@ import java.util.stream.Stream; import static java.util.stream.Collectors.toList; -import static org.opensearch.common.util.FeatureFlags.*; +import static org.opensearch.common.util.FeatureFlags.ARROW_STREAMS_SETTING; +import static org.opensearch.common.util.FeatureFlags.BACKGROUND_TASK_EXECUTION_EXPERIMENTAL; +import static org.opensearch.common.util.FeatureFlags.TELEMETRY; import static org.opensearch.env.NodeEnvironment.collectFileCacheDataPath; import static org.opensearch.index.ShardIndexingPressureSettings.SHARD_INDEXING_PRESSURE_ENABLED_ATTRIBUTE_KEY; import static org.opensearch.indices.RemoteStoreSettings.CLUSTER_REMOTE_STORE_PINNED_TIMESTAMP_ENABLED; @@ -1364,12 +1366,13 @@ protected Node( throw new IllegalStateException( String.format( Locale.ROOT, - "Only one StreamManagerPlugin can be installed. Found: %d", streamManagerPlugins.size() + "Only one StreamManagerPlugin can be installed. Found: %d", + streamManagerPlugins.size() ) ); } - if(!streamManagerPlugins.isEmpty()) { + if (!streamManagerPlugins.isEmpty()) { streamManager = streamManagerPlugins.get(0).getStreamManager(); logger.info("StreamManager initialized"); } @@ -1463,7 +1466,13 @@ protected Node( b.bind(SearchService.class).toInstance(searchService); b.bind(SearchTransportService.class).toInstance(searchTransportService); b.bind(SearchPhaseController.class) - .toInstance(new SearchPhaseController(namedWriteableRegistry, searchService::aggReduceContextBuilder, searchService.getStreamManager())); + .toInstance( + new SearchPhaseController( + namedWriteableRegistry, + searchService::aggReduceContextBuilder, + searchService.getStreamManager() + ) + ); b.bind(Transport.class).toInstance(transport); b.bind(TransportService.class).toInstance(transportService); b.bind(NetworkService.class).toInstance(networkService); diff --git a/server/src/main/java/org/opensearch/search/DefaultSearchContext.java b/server/src/main/java/org/opensearch/search/DefaultSearchContext.java index 0757f125eb216..d4966635422ea 100644 --- a/server/src/main/java/org/opensearch/search/DefaultSearchContext.java +++ b/server/src/main/java/org/opensearch/search/DefaultSearchContext.java @@ -146,7 +146,7 @@ final class DefaultSearchContext extends SearchContext { private final IndexShard indexShard; private final ClusterService clusterService; private final IndexService indexService; - private final StreamManager streamManager; + private final StreamManager streamManager; private final ContextIndexSearcher searcher; private final DfsSearchResult dfsResult; private final QuerySearchResult queryResult; diff --git a/server/src/main/java/org/opensearch/search/SearchService.java b/server/src/main/java/org/opensearch/search/SearchService.java index 889ca26bb9705..98fe726f15066 100644 --- a/server/src/main/java/org/opensearch/search/SearchService.java +++ b/server/src/main/java/org/opensearch/search/SearchService.java @@ -585,7 +585,7 @@ protected ReaderContext removeReaderContext(long id) { } @Override - protected void doStart() { } + protected void doStart() {} @Override protected void doStop() { @@ -802,7 +802,7 @@ private StreamSearchResult executeStreamPhase(ShardSearchRequest request, Search try (SearchOperationListenerExecutor executor = new SearchOperationListenerExecutor(context)) { loadOrExecuteStreamPhase(request, context); if (context.queryResult().hasSearchContext() == false && readerContext.singleSession()) { -// freeReaderContext(readerContext.id()); + // freeReaderContext(readerContext.id()); } } return context.streamSearchResult(); @@ -869,7 +869,6 @@ public void executeQueryPhase( }, wrapFailureListener(listener, readerContext, markAsUsed)); } - public void executeStreamPhase(QuerySearchRequest request, SearchShardTask task, ActionListener listener) { final ReaderContext readerContext = findReaderContext(request.contextId(), request.shardSearchRequest()); final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.shardSearchRequest()); diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/BucketsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/BucketsAggregator.java index 4054e5c9486f2..092a3cb39f650 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/BucketsAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/BucketsAggregator.java @@ -109,7 +109,11 @@ public final void grow(long maxBucketOrd) { @Override public void reset() { - docCounts = bigArrays.newLongArray(1, true); + try (LongArray oldArray = docCounts) { + // Create new array after releasing the old one + docCounts.fill(0, docCounts.size(), 0); + // docCounts = bigArrays.newLongArray(1, true); + } } /** diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java index 99b7a34157d98..c5a30e581d789 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java @@ -231,18 +231,18 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCol SortedSetDocValues globalOrds = valuesSource.globalOrdinalsValues(ctx); collectionStrategy.globalOrdsReady(globalOrds); - if (collectionStrategy instanceof DenseGlobalOrds - && this.resultStrategy instanceof StandardTermsResults - && sub == LeafBucketCollector.NO_OP_COLLECTOR) { - LeafBucketCollector termDocFreqCollector = termDocFreqCollector( - ctx, - globalOrds, - (ord, docCount) -> incrementBucketDocCount(collectionStrategy.globalOrdToBucketOrd(0, ord), docCount) - ); - if (termDocFreqCollector != null) { - return termDocFreqCollector; - } - } + // if (collectionStrategy instanceof DenseGlobalOrds + // && this.resultStrategy instanceof StandardTermsResults + // && sub == LeafBucketCollector.NO_OP_COLLECTOR) { + // LeafBucketCollector termDocFreqCollector = termDocFreqCollector( + // ctx, + // globalOrds, + // (ord, docCount) -> incrementBucketDocCount(collectionStrategy.globalOrdToBucketOrd(0, ord), docCount) + // ); + // if (termDocFreqCollector != null) { + // return termDocFreqCollector; + // } + // } SortedDocValues singleValues = DocValues.unwrapSingleton(globalOrds); if (singleValues != null) { diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalTerms.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalTerms.java index b8f9406ff55b9..6bbbc9a15978a 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalTerms.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/InternalTerms.java @@ -340,7 +340,7 @@ protected boolean lessThan(IteratorAndCurrent a, IteratorAndCurrent b) { B lastBucket = null; while (pq.size() > 0) { final IteratorAndCurrent top = pq.top(); - assert lastBucket == null || cmp.compare(top.current(), lastBucket) >= 0; + // assert lastBucket == null || cmp.compare(top.current(), lastBucket) >= 0; if (lastBucket != null && cmp.compare(top.current(), lastBucket) != 0) { // the key changes, reduce what we already buffered and reset the buffer for current buckets final B reduced = reduceBucket(currentBuckets, reduceContext); @@ -351,7 +351,7 @@ protected boolean lessThan(IteratorAndCurrent a, IteratorAndCurrent b) { currentBuckets.add(top.current()); if (top.hasNext()) { top.next(); - assert cmp.compare(top.current(), lastBucket) > 0 : "shards must return data sorted by key"; + // assert cmp.compare(top.current(), lastBucket) > 0 : "shards must return data sorted by key"; pq.updateTop(); } else { pq.pop(); diff --git a/server/src/main/java/org/opensearch/search/aggregations/support/StreamingAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/support/StreamingAggregator.java index 8fe647ea48e16..91cf7054d8eb6 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/support/StreamingAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/support/StreamingAggregator.java @@ -9,12 +9,10 @@ package org.opensearch.search.aggregations.support; import org.apache.arrow.vector.FieldVector; -import org.apache.arrow.vector.Float4Vector; -import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.UInt8Vector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.search.Collector; import org.apache.lucene.search.FilterCollector; import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.Scorable; @@ -23,11 +21,8 @@ import org.opensearch.search.aggregations.Aggregation; import org.opensearch.search.aggregations.Aggregations; import org.opensearch.search.aggregations.Aggregator; -import org.opensearch.search.aggregations.BucketCollectorProcessor; import org.opensearch.search.aggregations.InternalAggregation; -import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.search.aggregations.LeafBucketCollector; -import org.opensearch.search.aggregations.LeafBucketCollectorBase; import org.opensearch.search.aggregations.bucket.terms.InternalMappedTerms; import org.opensearch.search.aggregations.bucket.terms.InternalTerms; import org.opensearch.search.internal.SearchContext; @@ -36,7 +31,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.Supplier; public class StreamingAggregator extends FilterCollector { @@ -46,6 +40,7 @@ public class StreamingAggregator extends FilterCollector { private final StreamProducer.FlushSignal flushSignal; private final int batchSize; private final ShardId shardId; + /** * Sole constructor. * @@ -74,9 +69,9 @@ public LeafCollector getLeafCollector(LeafReaderContext context) throws IOExcept Map vectors = new HashMap<>(); vectors.put("ord", root.getVector("ord")); vectors.put("count", root.getVector("count")); - final int[] currentRow = {0}; + final int[] currentRow = { 0 }; return new LeafBucketCollector() { - + final LeafBucketCollector leaf = aggregator.getLeafCollector(context); @Override public void setScorer(Scorable scorer) throws IOException { @@ -85,57 +80,52 @@ public void setScorer(Scorable scorer) throws IOException { @Override public void collect(int doc, long owningBucketOrd) throws IOException { - final LeafBucketCollector leaf = aggregator.getLeafCollector(context); leaf.collect(doc); currentRow[0]++; - if (currentRow[0] == batchSize) { + if (currentRow[0] >= batchSize) { flushBatch(); } - - // hit batch size - - // flush } private void flushBatch() throws IOException { - InternalAggregation agg = aggregator.buildAggregations(new long[]{0})[0]; + int bucketCount = 0; + InternalAggregation agg = aggregator.buildAggregations(new long[] { 0 })[0]; if (agg instanceof InternalMappedTerms) { - InternalMappedTerms terms = (InternalMappedTerms) agg; + InternalMappedTerms terms = (InternalMappedTerms) agg; List buckets = terms.getBuckets(); - for (InternalTerms.Bucket bucket : buckets) { + for (int i = 0; i < buckets.size(); i++) { // Get key/value info - String key = bucket.getKeyAsString(); - long docCount = bucket.getDocCount(); + String key = buckets.get(i).getKeyAsString(); + long docCount = buckets.get(i).getDocCount(); - Aggregations aggregations = bucket.getAggregations(); + Aggregations aggregations = buckets.get(i).getAggregations(); for (Aggregation aggregation : aggregations) { - // TODO: subs + // TODO: subs } - // Write to vector storage - // e.g., for term and count vectors: -// VarCharVector keyVector = (VarCharVector) vectors.get("key"); -// keyVector.setSafe(i, key.getBytes()); FieldVector termVector = vectors.get("ord"); FieldVector countVector = vectors.get("count"); - ((VarCharVector) termVector).setSafe(0, key.getBytes()); - ((Float4Vector) countVector).setSafe(0, docCount); - - // Add the values... + ((VarCharVector) termVector).setSafe(i, key.getBytes()); + ((UInt8Vector) countVector).setSafe(i, docCount); + bucketCount++; } aggregator.reset(); - - // Also access high-level statistics -// long otherDocCount = terms.getSumOfOtherDocCounts(); -// long docCountError = terms.getDocCountError(); } + System.out.println("Row count " + bucketCount); // Reset for next batch + root.setRowCount(bucketCount); + flushSignal.awaitConsumption(100000); currentRow[0] = 0; - root.setRowCount(currentRow[0]); - flushSignal.awaitConsumption(1000); + } + + @Override + public void finish() throws IOException { + if (currentRow[0] > 0) { + flushBatch(); + } } }; } diff --git a/server/src/main/java/org/opensearch/search/internal/InternalSearchResponse.java b/server/src/main/java/org/opensearch/search/internal/InternalSearchResponse.java index efa1c13c4550b..faa6f6c6d8d38 100644 --- a/server/src/main/java/org/opensearch/search/internal/InternalSearchResponse.java +++ b/server/src/main/java/org/opensearch/search/internal/InternalSearchResponse.java @@ -42,7 +42,6 @@ import org.opensearch.core.xcontent.ToXContentFragment; import org.opensearch.search.SearchExtBuilder; import org.opensearch.search.SearchHits; -import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.search.profile.SearchProfileShardResults; import org.opensearch.search.stream.OSTicket; @@ -104,7 +103,18 @@ public InternalSearchResponse( List tickets, List result ) { - super(hits, aggregations, suggest, timedOut, terminatedEarly, profileResults, numReducePhases, searchExtBuilderList, tickets, result); + super( + hits, + aggregations, + suggest, + timedOut, + terminatedEarly, + profileResults, + numReducePhases, + searchExtBuilderList, + tickets, + result + ); } public InternalSearchResponse(StreamInput in) throws IOException { @@ -117,8 +127,8 @@ public InternalSearchResponse(StreamInput in) throws IOException { in.readOptionalWriteable(SearchProfileShardResults::new), in.readVInt(), readSearchExtBuildersOnOrAfter(in), - (in.readBoolean()? in.readList(OSTicket::new): null), - (in.readBoolean()? in.readList(StreamTargetResponse::new): null) + (in.readBoolean() ? in.readList(OSTicket::new) : null), + (in.readBoolean() ? in.readList(StreamTargetResponse::new) : null) ); } diff --git a/server/src/main/java/org/opensearch/search/internal/ShardStreamQueryResult.java b/server/src/main/java/org/opensearch/search/internal/ShardStreamQueryResult.java index 42028511af562..2860e65d42428 100644 --- a/server/src/main/java/org/opensearch/search/internal/ShardStreamQueryResult.java +++ b/server/src/main/java/org/opensearch/search/internal/ShardStreamQueryResult.java @@ -8,7 +8,6 @@ package org.opensearch.search.internal; -import org.opensearch.action.search.SearchContextId; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; diff --git a/server/src/main/java/org/opensearch/search/lookup/SearchLookup.java b/server/src/main/java/org/opensearch/search/lookup/SearchLookup.java index 0a53e30ce2ac3..6d8e3330bc042 100644 --- a/server/src/main/java/org/opensearch/search/lookup/SearchLookup.java +++ b/server/src/main/java/org/opensearch/search/lookup/SearchLookup.java @@ -51,7 +51,7 @@ * @opensearch.api */ @PublicApi(since = "1.0.0") -public class /**/SearchLookup { +public class /**/ SearchLookup { /** * The maximum depth of field dependencies. * When a runtime field's doc values depends on another runtime field's doc values, diff --git a/server/src/main/java/org/opensearch/search/query/StreamQueryResponse.java b/server/src/main/java/org/opensearch/search/query/StreamQueryResponse.java index 12f46e0be483e..e0ee23ad28801 100644 --- a/server/src/main/java/org/opensearch/search/query/StreamQueryResponse.java +++ b/server/src/main/java/org/opensearch/search/query/StreamQueryResponse.java @@ -9,7 +9,6 @@ package org.opensearch.search.query; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.transport.TransportResponse; import org.opensearch.search.SearchPhaseResult; import java.io.IOException; diff --git a/server/src/main/java/org/opensearch/search/query/StreamSearchPhase.java b/server/src/main/java/org/opensearch/search/query/StreamSearchPhase.java index 78026b6392e1f..92a385f4c947c 100644 --- a/server/src/main/java/org/opensearch/search/query/StreamSearchPhase.java +++ b/server/src/main/java/org/opensearch/search/query/StreamSearchPhase.java @@ -10,45 +10,32 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.apache.lucene.search.Collector; import org.apache.lucene.search.Query; import org.opensearch.arrow.StreamManager; import org.opensearch.arrow.StreamProducer; import org.opensearch.arrow.StreamTicket; -import org.opensearch.index.mapper.MappedFieldType; -import org.opensearch.index.query.QueryShardContext; import org.opensearch.search.SearchContextSourcePrinter; import org.opensearch.search.aggregations.AggregationProcessor; import org.opensearch.search.aggregations.Aggregator; import org.opensearch.search.aggregations.support.StreamingAggregator; -import org.opensearch.search.fetch.subphase.FieldAndFormat; import org.opensearch.search.internal.ContextIndexSearcher; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.profile.ProfileShardResult; import org.opensearch.search.profile.SearchProfileShardResults; import org.opensearch.search.stream.OSTicket; import org.opensearch.search.stream.StreamSearchResult; -import org.opensearch.search.stream.collector.ArrowCollector; -import org.opensearch.search.stream.collector.ArrowDocIdCollector; -import org.opensearch.search.stream.collector.ArrowFieldAdaptor; import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; -import java.util.function.Supplier; - -import static org.opensearch.search.stream.collector.ArrowFieldAdaptor.getArrowType; /** * Produce stream from a shard search @@ -72,8 +59,6 @@ public void execute(SearchContext searchContext) throws QueryPhaseExecutionExcep aggregationProcessor.preProcess(searchContext); executeInternal(searchContext, this.getQueryPhaseSearcher()); -// aggregationProcessor.postProcess(searchContext); - if (searchContext.getProfilers() != null) { ProfileShardResult shardResults = SearchProfileShardResults.buildShardResults( searchContext.getProfilers(), @@ -100,21 +85,6 @@ public boolean searchWith( return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); } -// @Override -// public AggregationProcessor aggregationProcessor(SearchContext searchContext) { -// return new AggregationProcessor() { -// @Override -// public void preProcess(SearchContext context) { -// -// } -// -// @Override -// public void postProcess(SearchContext context) { -// -// } -// }; -// } - protected boolean searchWithCollector( SearchContext searchContext, ContextIndexSearcher searcher, @@ -134,26 +104,7 @@ private boolean searchWithCollector( boolean timeoutSet ) { -// List fields = searchContext.fetchFieldsContext().fields(); -// -// // map from OpenSearch field to Arrow Field type -// List arrowFieldAdaptors = new ArrayList<>(); -// fields.forEach(field -> { -// System.out.println("field: " + field.field); -// QueryShardContext shardContext = searchContext.getQueryShardContext(); -// MappedFieldType fieldType = shardContext.fieldMapper(field.field); -// ArrowType arrowType = getArrowType(fieldType.typeName()); -// arrowFieldAdaptors.add(new ArrowFieldAdaptor(field.field, arrowType, fieldType.typeName())); -// }); - QuerySearchResult queryResult = searchContext.queryResult(); -// try { -// Collector collector = QueryCollectorContext.createQueryCollector(collectors); -// System.out.println(collector); -// } catch (IOException e) { -// throw new RuntimeException(e); -// } - StreamManager streamManager = searchContext.streamManager(); if (streamManager == null) { throw new RuntimeException("StreamManager not setup"); @@ -162,17 +113,20 @@ private boolean searchWithCollector( @Override public BatchedJob createJob(BufferAllocator allocator) { return new BatchedJob() { - - @Override public void run(VectorSchemaRoot root, StreamProducer.FlushSignal flushSignal) { try { - final StreamingAggregator arrowDocIdCollector = new StreamingAggregator((Aggregator) QueryCollectorContext.createQueryCollector(collectors), searchContext, root, 1, flushSignal, searchContext.shardTarget().getShardId()); + final StreamingAggregator arrowDocIdCollector = new StreamingAggregator( + (Aggregator) QueryCollectorContext.createQueryCollector(collectors), + searchContext, + root, + 1_000_000, + flushSignal, + searchContext.shardTarget().getShardId() + ); try { searcher.search(query, arrowDocIdCollector); } catch (EarlyTerminatingCollector.EarlyTerminationException e) { - // EarlyTerminationException is not caught in ContextIndexSearcher to allow force termination of collection. Postcollection - // still needs to be processed for Aggregations when early termination takes place. searchContext.bucketCollectorProcessor().processPostCollection(arrowDocIdCollector); queryResult.terminatedEarly(true); } @@ -183,7 +137,8 @@ public void run(VectorSchemaRoot root, StreamProducer.FlushSignal flushSignal) { } queryResult.searchTimedOut(true); } - if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER && queryResult.terminatedEarly() == null) { + if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER + && queryResult.terminatedEarly() == null) { queryResult.terminatedEarly(false); } @@ -205,21 +160,9 @@ public void onCancel() { @Override public VectorSchemaRoot createRoot(BufferAllocator allocator) { Map arrowFields = new HashMap<>(); - -// Field docIdField = new Field("ord", FieldType.notNullable(new ArrowType.Int(32, true)), null); -// arrowFields.put("ord", docIdField); - Field scoreField = new Field( - "count", - FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)), - null - ); - arrowFields.put("count", scoreField); - -// arrowFieldAdaptors.forEach(field -> { -// Field arrowField = new Field(field.getFieldName(), FieldType.nullable(field.getArrowType()), null); -// arrowFields.put(field.getFieldName(), arrowField); -// }); - arrowFields.put("ord", new Field("ord", FieldType.notNullable(new ArrowType.Utf8()), null)); + Field countField = new Field("count", FieldType.nullable(new ArrowType.Int(64, false)), null); + arrowFields.put("count", countField); + arrowFields.put("ord", new Field("ord", FieldType.nullable(new ArrowType.Utf8()), null)); Schema schema = new Schema(arrowFields.values()); return VectorSchemaRoot.create(schema, allocator); } diff --git a/server/src/main/java/org/opensearch/search/stream/OSTicket.java b/server/src/main/java/org/opensearch/search/stream/OSTicket.java index 024cd169f748c..d37dbc1a6d9da 100644 --- a/server/src/main/java/org/opensearch/search/stream/OSTicket.java +++ b/server/src/main/java/org/opensearch/search/stream/OSTicket.java @@ -50,9 +50,6 @@ public void writeTo(StreamOutput out) throws IOException { @Override public String toString() { - return "OSTicket{" + - "ticketID='" + streamTicket.getTicketID() + '\'' + - ", nodeID='" + streamTicket.getNodeID() + '\'' + - '}'; + return "OSTicket{" + "ticketID='" + streamTicket.getTicketID() + '\'' + ", nodeID='" + streamTicket.getNodeID() + '\'' + '}'; } } diff --git a/server/src/main/java/org/opensearch/search/stream/StreamSearchResult.java b/server/src/main/java/org/opensearch/search/stream/StreamSearchResult.java index 8e9ae3c9b29d7..5f20a5f705256 100644 --- a/server/src/main/java/org/opensearch/search/stream/StreamSearchResult.java +++ b/server/src/main/java/org/opensearch/search/stream/StreamSearchResult.java @@ -9,11 +9,11 @@ package org.opensearch.search.stream; import org.opensearch.common.annotation.ExperimentalApi; - import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.fetch.FetchSearchResult; import org.opensearch.search.internal.ShardSearchContextId; import org.opensearch.search.internal.ShardSearchRequest; import org.opensearch.search.query.QuerySearchResult; @@ -25,10 +25,12 @@ public class StreamSearchResult extends SearchPhaseResult { private List flightTickets; private final QuerySearchResult queryResult; + private final FetchSearchResult fetchResult; public StreamSearchResult() { super(); this.queryResult = QuerySearchResult.nullInstance(); + this.fetchResult = new FetchSearchResult(); } public StreamSearchResult(StreamInput in) throws IOException { @@ -39,12 +41,14 @@ public StreamSearchResult(StreamInput in) throws IOException { flightTickets = in.readList(OSTicket::new); } queryResult = new QuerySearchResult(contextId, getSearchShardTarget(), getShardSearchRequest()); + fetchResult = new FetchSearchResult(contextId, getSearchShardTarget()); setSearchShardTarget(getSearchShardTarget()); } public StreamSearchResult(ShardSearchContextId id, SearchShardTarget shardTarget, ShardSearchRequest searchRequest) { this.contextId = id; queryResult = new QuerySearchResult(id, shardTarget, searchRequest); + fetchResult = new FetchSearchResult(id, shardTarget); setSearchShardTarget(shardTarget); setShardSearchRequest(searchRequest); } @@ -67,7 +71,12 @@ public void setShardIndex(int shardIndex) { @Override public QuerySearchResult queryResult() { - return queryResult; + return queryResult; + } + + @Override + public FetchSearchResult fetchResult() { + return fetchResult; } public List getFlightTickets() { diff --git a/server/src/main/java/org/opensearch/search/stream/collector/ArrowCollector.java b/server/src/main/java/org/opensearch/search/stream/collector/ArrowCollector.java index 9087dcd4043c8..ffbf1cae9db41 100644 --- a/server/src/main/java/org/opensearch/search/stream/collector/ArrowCollector.java +++ b/server/src/main/java/org/opensearch/search/stream/collector/ArrowCollector.java @@ -81,7 +81,7 @@ public LeafCollector getLeafCollector(LeafReaderContext context) throws IOExcept } }); - final int[] currentRow = {0}; + final int[] currentRow = { 0 }; return new LeafCollector() { private final int[] docIds = new int[batchSize]; private final float[] scores = new float[batchSize]; @@ -163,7 +163,7 @@ public void setScorer(Scorable scorable) throws IOException { } } - ; + ; } @Override diff --git a/server/src/main/java/org/opensearch/search/stream/collector/ArrowDocIdCollector.java b/server/src/main/java/org/opensearch/search/stream/collector/ArrowDocIdCollector.java index 7879daf0563a9..47f7b7b0c57be 100644 --- a/server/src/main/java/org/opensearch/search/stream/collector/ArrowDocIdCollector.java +++ b/server/src/main/java/org/opensearch/search/stream/collector/ArrowDocIdCollector.java @@ -7,11 +7,11 @@ */ package org.opensearch.search.stream.collector; + import org.apache.arrow.vector.Float4Vector; import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.holders.VarCharHolder; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Collector; import org.apache.lucene.search.FilterCollector; @@ -35,7 +35,13 @@ public class ArrowDocIdCollector extends FilterCollector { private int currentRow; private SearchShardTarget target; - public ArrowDocIdCollector(Collector in, VectorSchemaRoot root, StreamProducer.FlushSignal flushSignal, int batchSize, SearchShardTarget target) { + public ArrowDocIdCollector( + Collector in, + VectorSchemaRoot root, + StreamProducer.FlushSignal flushSignal, + int batchSize, + SearchShardTarget target + ) { super(in); this.root = root; this.docIDVector = (IntVector) root.getVector("docID"); @@ -59,10 +65,9 @@ public ScoreMode scoreMode() { return ScoreMode.TOP_DOCS; } - @Override public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { - LeafCollector inner = (this.in == null ? null: super.getLeafCollector(context)); + LeafCollector inner = (this.in == null ? null : super.getLeafCollector(context)); return new LeafCollector() { private Scorable scorer; diff --git a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java index cab92d1200664..cc2f289530c3c 100644 --- a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java @@ -797,7 +797,8 @@ private SearchDfsQueryThenFetchAsyncAction createSearchDfsQueryThenFetchAsyncAct SearchPhaseController controller = new SearchPhaseController( writableRegistry(), r -> InternalAggregationTestCase.emptyReduceContextBuilder(), - searchService.getStreamManager()); + searchService.getStreamManager() + ); SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); SearchTask task = new SearchTask(0, "n/a", "n/a", () -> "test", null, Collections.emptyMap()); Executor executor = OpenSearchExecutors.newDirectExecutorService(); @@ -852,7 +853,8 @@ private SearchQueryThenFetchAsyncAction createSearchQueryThenFetchAsyncAction( SearchPhaseController controller = new SearchPhaseController( writableRegistry(), r -> InternalAggregationTestCase.emptyReduceContextBuilder(), - searchService.getStreamManager()); + searchService.getStreamManager() + ); SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); SearchTask task = new SearchTask(0, "n/a", "n/a", () -> "test", null, Collections.emptyMap()); Executor executor = OpenSearchExecutors.newDirectExecutorService(); @@ -915,7 +917,8 @@ private FetchSearchPhase createFetchSearchPhase() { SearchPhaseController controller = new SearchPhaseController( writableRegistry(), r -> InternalAggregationTestCase.emptyReduceContextBuilder(), - searchService.getStreamManager()); + searchService.getStreamManager() + ); MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1); QueryPhaseResultConsumer results = controller.newSearchPhaseResults( OpenSearchExecutors.newDirectExecutorService(), diff --git a/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java b/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java index e4e9afa797f42..b5805efad3d7a 100644 --- a/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java +++ b/server/src/test/java/org/opensearch/action/search/CanMatchPreFilterSearchPhaseTests.java @@ -665,7 +665,8 @@ public void sendCanMatch( SearchPhaseController controller = new SearchPhaseController( writableRegistry(), r -> InternalAggregationTestCase.emptyReduceContextBuilder(), - searchService.getStreamManager()); + searchService.getStreamManager() + ); QueryPhaseResultConsumer resultConsumer = new QueryPhaseResultConsumer( searchRequest, diff --git a/server/src/test/java/org/opensearch/action/search/DfsQueryPhaseTests.java b/server/src/test/java/org/opensearch/action/search/DfsQueryPhaseTests.java index de29175787a1e..80f8354d776fa 100644 --- a/server/src/test/java/org/opensearch/action/search/DfsQueryPhaseTests.java +++ b/server/src/test/java/org/opensearch/action/search/DfsQueryPhaseTests.java @@ -327,6 +327,10 @@ public void run() throws IOException { } private SearchPhaseController searchPhaseController() { - return new SearchPhaseController(writableRegistry(), request -> InternalAggregationTestCase.emptyReduceContextBuilder(), searchService.getStreamManager()); + return new SearchPhaseController( + writableRegistry(), + request -> InternalAggregationTestCase.emptyReduceContextBuilder(), + searchService.getStreamManager() + ); } } diff --git a/server/src/test/java/org/opensearch/action/search/FetchSearchPhaseTests.java b/server/src/test/java/org/opensearch/action/search/FetchSearchPhaseTests.java index 92a3f01f160f0..fd77c6ad69554 100644 --- a/server/src/test/java/org/opensearch/action/search/FetchSearchPhaseTests.java +++ b/server/src/test/java/org/opensearch/action/search/FetchSearchPhaseTests.java @@ -64,7 +64,8 @@ public void testShortcutQueryAndFetchOptimization() { SearchPhaseController controller = new SearchPhaseController( writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder(), - searchService.getStreamManager()); + searchService.getStreamManager() + ); MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1); QueryPhaseResultConsumer results = controller.newSearchPhaseResults( OpenSearchExecutors.newDirectExecutorService(), @@ -126,7 +127,8 @@ public void testFetchTwoDocument() { SearchPhaseController controller = new SearchPhaseController( writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder(), - searchService.getStreamManager()); + searchService.getStreamManager() + ); QueryPhaseResultConsumer results = controller.newSearchPhaseResults( OpenSearchExecutors.newDirectExecutorService(), new NoopCircuitBreaker(CircuitBreaker.REQUEST), @@ -222,7 +224,8 @@ public void testFailFetchOneDoc() { SearchPhaseController controller = new SearchPhaseController( writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder(), - searchService.getStreamManager()); + searchService.getStreamManager() + ); QueryPhaseResultConsumer results = controller.newSearchPhaseResults( OpenSearchExecutors.newDirectExecutorService(), new NoopCircuitBreaker(CircuitBreaker.REQUEST), @@ -319,7 +322,8 @@ public void testFetchDocsConcurrently() throws InterruptedException { SearchPhaseController controller = new SearchPhaseController( writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder(), - searchService.getStreamManager()); + searchService.getStreamManager() + ); MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(numHits); QueryPhaseResultConsumer results = controller.newSearchPhaseResults( OpenSearchExecutors.newDirectExecutorService(), @@ -410,7 +414,8 @@ public void testExceptionFailsPhase() { SearchPhaseController controller = new SearchPhaseController( writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder(), - searchService.getStreamManager()); + searchService.getStreamManager() + ); QueryPhaseResultConsumer results = controller.newSearchPhaseResults( OpenSearchExecutors.newDirectExecutorService(), new NoopCircuitBreaker(CircuitBreaker.REQUEST), @@ -502,7 +507,8 @@ public void testCleanupIrrelevantContexts() { // contexts that are not fetched s SearchPhaseController controller = new SearchPhaseController( writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder(), - searchService.getStreamManager()); + searchService.getStreamManager() + ); QueryPhaseResultConsumer results = controller.newSearchPhaseResults( OpenSearchExecutors.newDirectExecutorService(), new NoopCircuitBreaker(CircuitBreaker.REQUEST), diff --git a/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java index 38a49005e5e0e..2af9410d4e4b0 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchQueryThenFetchAsyncActionTests.java @@ -209,7 +209,8 @@ public void sendExecuteQuery( SearchPhaseController controller = new SearchPhaseController( writableRegistry(), r -> InternalAggregationTestCase.emptyReduceContextBuilder(), - searchService.getStreamManager()); + searchService.getStreamManager() + ); SearchTask task = new SearchTask(0, "n/a", "n/a", () -> "test", null, Collections.emptyMap()); QueryPhaseResultConsumer resultConsumer = new QueryPhaseResultConsumer( searchRequest, diff --git a/server/src/test/java/org/opensearch/index/mapper/DerivedFieldMapperQueryTests.java b/server/src/test/java/org/opensearch/index/mapper/DerivedFieldMapperQueryTests.java index 0d5ecbf8134c5..c744f2592e24f 100644 --- a/server/src/test/java/org/opensearch/index/mapper/DerivedFieldMapperQueryTests.java +++ b/server/src/test/java/org/opensearch/index/mapper/DerivedFieldMapperQueryTests.java @@ -435,7 +435,7 @@ public void execute() { query = geoShapeQuery("geopoint", new Rectangle(0.0, 55.0, 55.0, 0.0)).toQuery(queryShardContext); topDocs = searcher.search(query, 10); assertEquals(4, topDocs.totalHits.value); - } + } } } diff --git a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java index f29b90f770e40..2d66248454348 100644 --- a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java +++ b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java @@ -742,7 +742,8 @@ public void testTransformSearchPhase() { SearchPhaseController controller = new SearchPhaseController( writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder(), - searchService.getStreamManager()); + searchService.getStreamManager() + ); SearchPhaseContext searchPhaseContext = new MockSearchPhaseContext(10); QueryPhaseResultConsumer searchPhaseResults = new QueryPhaseResultConsumer( searchPhaseContext.getRequest(), diff --git a/server/src/test/java/org/opensearch/search/query/QueryPhaseTests.java b/server/src/test/java/org/opensearch/search/query/QueryPhaseTests.java index 34aeb466ae360..954e9b22c3c2a 100644 --- a/server/src/test/java/org/opensearch/search/query/QueryPhaseTests.java +++ b/server/src/test/java/org/opensearch/search/query/QueryPhaseTests.java @@ -556,8 +556,8 @@ public void testArrow() throws Exception { final int numDocs = scaledRandomIntBetween(100, 200); for (int i = 0; i < numDocs; ++i) { Document doc = new Document(); - doc.add(new StringField("joinField", Integer.toString(i%10), Store.NO)); - doc.add(new SortedSetDocValuesField("joinField", new BytesRef(Integer.toString(i%10)))); + doc.add(new StringField("joinField", Integer.toString(i % 10), Store.NO)); + doc.add(new SortedSetDocValuesField("joinField", new BytesRef(Integer.toString(i % 10)))); w.addDocument(doc); } w.close(); diff --git a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java index a224145465b03..ce815fc729ccb 100644 --- a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java @@ -2289,7 +2289,8 @@ public void onFailure(final Exception e) { SearchPhaseController searchPhaseController = new SearchPhaseController( writableRegistry(), searchService::aggReduceContextBuilder, - searchService.getStreamManager()); + searchService.getStreamManager() + ); SearchRequestOperationsCompositeListenerFactory searchRequestOperationsCompositeListenerFactory = new SearchRequestOperationsCompositeListenerFactory(); actions.put(