Optimize the Term Aggs using Streams
Signed-off-by: Rishabh Maurya <[email protected]>
rishabhmaurya committed Jan 31, 2025
1 parent edabf74 commit 0c62d1f
Showing 10 changed files with 312 additions and 176 deletions.
SearchPhaseController.java
@@ -32,15 +32,6 @@

package org.opensearch.action.search;

-import org.apache.arrow.vector.BigIntVector;
-import org.apache.arrow.vector.BitVector;
-import org.apache.arrow.vector.FieldVector;
-import org.apache.arrow.vector.Float4Vector;
-import org.apache.arrow.vector.Float8Vector;
-import org.apache.arrow.vector.IntVector;
-import org.apache.arrow.vector.UInt8Vector;
-import org.apache.arrow.vector.VarCharVector;
-import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.FieldDoc;
@@ -54,10 +45,7 @@
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.TotalHits.Relation;
import org.apache.lucene.search.grouping.CollapseTopFieldDocs;
-import org.apache.lucene.util.BytesRef;
import org.opensearch.arrow.spi.StreamManager;
-import org.opensearch.arrow.spi.StreamReader;
-import org.opensearch.arrow.spi.StreamTicket;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.core.common.breaker.CircuitBreaker;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
@@ -72,6 +60,7 @@
import org.opensearch.search.aggregations.InternalAggregation.ReduceContext;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.aggregations.InternalOrder;
+import org.opensearch.search.aggregations.bucket.terms.InternalTerms;
import org.opensearch.search.aggregations.bucket.terms.StringTerms;
import org.opensearch.search.aggregations.bucket.terms.TermsAggregator;
import org.opensearch.search.builder.SearchSourceBuilder;
@@ -82,7 +71,9 @@
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.profile.ProfileShardResult;
import org.opensearch.search.profile.SearchProfileShardResults;
+import org.opensearch.search.query.BatchProcessor;
import org.opensearch.search.query.QuerySearchResult;
+import org.opensearch.search.query.TicketProcessor;
import org.opensearch.search.sort.SortedWiderNumericSortField;
import org.opensearch.search.stream.OSTicket;
import org.opensearch.search.stream.StreamSearchResult;
@@ -93,13 +84,13 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
+import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import java.util.concurrent.Callable;
+import java.util.PriorityQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
-import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.function.Consumer;
@@ -711,119 +702,56 @@ static int getTopDocsSize(SearchRequest request) {
: source.from());
}

-    static class TicketProcessorResult {
-        private final Map<String, Long> bucketMap;
-        private final int rowCount;
-
-        public TicketProcessorResult(Map<String, Long> bucketMap, int rowCount) {
-            this.bucketMap = bucketMap;
-            this.rowCount = rowCount;
-        }
-
-        public Map<String, Long> getBucketMap() {
-            return bucketMap;
-        }
-
-        public int getRowCount() {
-            return rowCount;
-        }
-    }
-
-    static class TicketProcessor implements Callable<TicketProcessorResult> {
-        private final byte[] ticket;
-        private final StreamManager streamManager;
-
-        public TicketProcessor(byte[] ticket, StreamManager streamManager) {
-            this.ticket = ticket;
-            this.streamManager = streamManager;
-        }
-
-        @Override
-        public TicketProcessorResult call() throws Exception {
-            Map<String, Long> localBucketMap = new HashMap<>();
-            int localRowCount = 0;
-            StreamTicket streamTicket = streamManager.getStreamTicketFactory().fromBytes(ticket);
-            StreamReader streamIterator = streamManager.getStreamReader(streamTicket);
-
-            while (streamIterator.next()) {
-                VectorSchemaRoot root = streamIterator.getRoot();
-                int rowCount = root.getRowCount();
-                localRowCount += rowCount;
-
-                for (int row = 0; row < rowCount; row++) {
-                    FieldVector ord = root.getVector("ord");
-                    FieldVector count = root.getVector("count");
-                    long countValue = (long) getValue(count, row);
-                    String ordValue = (String) getValue(ord, row);
-                    localBucketMap.merge(ordValue, countValue, Long::sum);
-                }
-            }
-            return new TicketProcessorResult(localBucketMap, localRowCount);
-        }
-    }
-
     public ReducedQueryPhase reducedFromStream(
         List<StreamSearchResult> list,
         InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
         boolean performFinalReduce,
         Executor executor
     ) {
-        try {
-            List<byte[]> tickets = list.stream()
-                .flatMap(r -> r.getFlightTickets().stream())
-                .map(OSTicket::getBytes)
-                .collect(Collectors.toList());
-
-            List<CompletableFuture<TicketProcessorResult>> futures = new ArrayList<>();
-            int totalRows = 0;
-            Map<String, Long> bucketMap = new ConcurrentHashMap<>();
+        List<byte[]> tickets = list.stream()
+            .flatMap(r -> r.getFlightTickets().stream())
+            .map(OSTicket::getBytes)
+            .collect(Collectors.toList());
+
+        List<ScoreDoc> scoreDocs = new ArrayList<>();
+        List<InternalAggregation> aggs = new ArrayList<>();
+        List<CompletableFuture<TicketProcessor.TicketProcessorResult>> producerFutures = new ArrayList<>();
+        BatchProcessor batchProcessor = new BatchProcessor();
+        int totalRows = 0;
+
+        CompletableFuture<Void> consumerFuture = CompletableFuture.runAsync(() -> {
-            try {
-                for (byte[] ticket : tickets) {
-                    CompletableFuture<TicketProcessorResult> future = CompletableFuture.supplyAsync(() -> {
-                        try {
-                            return new TicketProcessor(ticket, streamManager).call();
-                        } catch (Exception e) {
-                            throw new CompletionException(e);
-                        }
-                    }, executor);
-                    futures.add(future);
-                }
-
-                CompletableFuture<Void> allFutures = CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]));
+            try {
+                batchProcessor.processBatches();
+            } catch (Exception e) {
+                throw new CompletionException(e);
+            }
+        }, executor);
+
-                allFutures.join();
-                for (CompletableFuture<TicketProcessorResult> future : futures) {
-                    TicketProcessorResult result = future.get();
-                    totalRows += result.getRowCount();
-                    result.getBucketMap().forEach((key, value) -> bucketMap.merge(key, value, Long::sum));
-                }
-            } catch (InterruptedException | ExecutionException e) {
-                Thread.currentThread().interrupt();
-                throw new RuntimeException("Error processing tickets in parallel", e);
-            }
+        try {
+            for (byte[] ticket : tickets) {
+                CompletableFuture<TicketProcessor.TicketProcessorResult> future = CompletableFuture.supplyAsync(() -> {
+                    try {
+                        return new TicketProcessor(ticket, streamManager, batchProcessor.getBatchQueue()).call();
+                    } catch (Exception e) {
+                        throw new CompletionException(e);
+                    }
+                }, executor);
+                producerFutures.add(future);
+            }
+            CompletableFuture<Void> allProducers = CompletableFuture.allOf(producerFutures.toArray(new CompletableFuture[0]));
+            allProducers.join();
+            batchProcessor.markProducersComplete();
+            consumerFuture.join();
+
+            for (CompletableFuture<TicketProcessor.TicketProcessorResult> future : producerFutures) {
+                TicketProcessor.TicketProcessorResult result = future.get();
+                totalRows += result.getRowCount();
+            }

             TotalHits totalHits = new TotalHits(totalRows, Relation.EQUAL_TO);

-            List<ScoreDoc> scoreDocs = new ArrayList<>();
             List<BucketOrder> orders = new ArrayList<>();
             orders.add(BucketOrder.count(false));
-            List<StringTerms.Bucket> buckets = new ArrayList<>();
-            bucketMap.entrySet().stream().sorted(Map.Entry.<String, Long>comparingByValue().reversed()).limit(500).forEach(entry -> {
-                buckets.add(
-                    new StringTerms.Bucket(
-                        new BytesRef(entry.getKey()),
-                        entry.getValue(),
-                        new InternalAggregations(List.of()),
-                        false,
-                        0,
-                        DocValueFormat.RAW
-                    )
-                );
-            });
-
-            List<InternalAggregation> aggs = new ArrayList<>();
+            List<StringTerms.Bucket> buckets = getTop500Buckets(batchProcessor.getMergedBuckets());
             aggs.add(
                 new StringTerms(
                     "categories",
Expand Down Expand Up @@ -857,34 +785,38 @@ public ReducedQueryPhase reducedFromStream(
                 totalRows == 0,
                 list.stream().flatMap(ssr -> ssr.getFlightTickets().stream()).collect(Collectors.toList())
             );
-        } catch (Exception e) {
-            throw new RuntimeException(e);
+        } catch (InterruptedException | ExecutionException e) {
+            Thread.currentThread().interrupt();
+            throw new RuntimeException("Error processing tickets in parallel", e);
         }
     }

-    private static Object getValue(FieldVector vector, int index) {
-        if (vector.isNull(index)) {
-            return "null";
-        }
-
-        if (vector instanceof IntVector) {
-            return ((IntVector) vector).get(index);
-        } else if (vector instanceof BigIntVector) {
-            return ((BigIntVector) vector).get(index);
-        } else if (vector instanceof Float4Vector) {
-            return ((Float4Vector) vector).get(index);
-        } else if (vector instanceof Float8Vector) {
-            return ((Float8Vector) vector).get(index);
-        } else if (vector instanceof VarCharVector) {
-            return new String(((VarCharVector) vector).get(index));
-        } else if (vector instanceof BitVector) {
-            return ((BitVector) vector).get(index) != 0;
-        } else if (vector instanceof UInt8Vector) {
-            return ((UInt8Vector) vector).get(index);
-        }
-        // Add more types as needed
-
-        return "Unsupported type: " + vector.getClass().getSimpleName();
-    }
+    private List<StringTerms.Bucket> getTop500Buckets(List<StringTerms.Bucket> finalBuckets) {
+        PriorityQueue<StringTerms.Bucket> top500Queue = new PriorityQueue<>(
+            500,
+            Comparator.comparingLong(InternalTerms.Bucket::getDocCount)
+        );
+
+        int i = 0;
+        for (; i < Math.min(500, finalBuckets.size()); i++) {
+            top500Queue.offer(finalBuckets.get(i));
+        }
+
+        for (; i < finalBuckets.size(); i++) {
+            StringTerms.Bucket bucket = finalBuckets.get(i);
+            if (bucket.getDocCount() > top500Queue.peek().getDocCount()) {
+                top500Queue.poll();
+                top500Queue.offer(bucket);
+            }
+        }
+
+        ArrayList<StringTerms.Bucket> result = new ArrayList<>(top500Queue.size());
+        while (!top500Queue.isEmpty()) {
+            result.addFirst(top500Queue.poll());
+        }
+        result.sort((o1, o2) -> Long.compare(o2.getDocCount(), o1.getDocCount()));
+        return result;
+    }

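The five changed files not shown here include the new org.opensearch.search.query.BatchProcessor and TicketProcessor sources, so only their call sites are visible above: the reduce path now runs one TicketProcessor producer per flight ticket, each pushing partial ord-to-count batches onto a shared queue, while a single BatchProcessor consumer drains the queue and merges counts. The sketch below is a minimal reconstruction of that consumer contract inferred from the call sites (getBatchQueue(), processBatches(), markProducersComplete(), getMergedBuckets()); BatchProcessorSketch and the plain Map<String, Long> batch type are illustrative stand-ins, and the real getMergedBuckets() returns List<StringTerms.Bucket>.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

// Hypothetical consumer-side contract; not the actual BatchProcessor source.
final class BatchProcessorSketch {
    private final BlockingQueue<Map<String, Long>> batchQueue = new LinkedBlockingQueue<>();
    private final Map<String, Long> merged = new HashMap<>();
    private volatile boolean producersComplete = false;

    // Producers (one TicketProcessor per flight ticket) enqueue batches here.
    BlockingQueue<Map<String, Long>> getBatchQueue() {
        return batchQueue;
    }

    // Called once every producer future has completed (see allProducers.join()).
    void markProducersComplete() {
        producersComplete = true;
    }

    // Runs on the single consumer thread started by consumerFuture: drains
    // batches and merges counts, so no ConcurrentHashMap contention is needed.
    void processBatches() throws InterruptedException {
        while (!(producersComplete && batchQueue.isEmpty())) {
            Map<String, Long> batch = batchQueue.poll(10, TimeUnit.MILLISECONDS);
            if (batch == null) {
                continue; // re-check the completion flag
            }
            batch.forEach((ord, count) -> merged.merge(ord, count, Long::sum));
        }
    }

    // The real method returns List<StringTerms.Bucket> for getTop500Buckets.
    List<Map.Entry<String, Long>> getMergedBuckets() {
        return new ArrayList<>(merged.entrySet());
    }
}
```

Under this shape, processBatches() runs concurrently with the producers, and the markProducersComplete()/consumerFuture.join() pair in reducedFromStream gives a clean shutdown without poison-pill sentinels.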
SearchResponseSections.java
@@ -111,7 +111,19 @@ public SearchResponseSections(
         List<ProcessorExecutionDetail> processorResult,
         List<OSTicket> tickets
     ) {
-        this(hits, aggregations, suggest, timedOut, terminatedEarly, profileResults, numReducePhases, searchExtBuilders, processorResult, tickets, null);
+        this(
+            hits,
+            aggregations,
+            suggest,
+            timedOut,
+            terminatedEarly,
+            profileResults,
+            numReducePhases,
+            searchExtBuilders,
+            processorResult,
+            tickets,
+            null
+        );
     }

public SearchResponseSections(
GlobalOrdinalsStringTermsAggregator.java
@@ -93,7 +93,7 @@ public class GlobalOrdinalsStringTermsAggregator extends AbstractStringTermsAggr

     private final LongPredicate acceptedGlobalOrdinals;
     private final long valueCount;
-    private final String fieldName;
+    public final String fieldName;
     private Weight weight;
     protected final CollectionStrategy collectionStrategy;
     private final SetOnce<SortedSetDocValues> dvs = new SetOnce<>();
InternalTerms.java
@@ -170,6 +170,10 @@ public void setDocCountError(long docCountError) {
         this.docCountError = docCountError;
     }

+    public void setDocCount(long dc) {
+        docCount = dc;
+    }
+
     @Override
     public void setDocCountError(Function<Long, Long> updater) {
         this.docCountError = updater.apply(this.docCountError);
StringTerms.java
@@ -58,7 +58,7 @@ public class StringTerms extends InternalMappedTerms<StringTerms, StringTerms.Bu
      * @opensearch.internal
      */
     public static class Bucket extends InternalTerms.Bucket<Bucket> {
-        BytesRef termBytes;
+        public BytesRef termBytes;

         public Bucket(
             BytesRef term,
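Taken together, the visibility changes in the last three files (fieldName and termBytes made public, plus the new setDocCount setter) suggest that bucket construction and count merging now happen outside the aggregation framework, in the stream-merge path. A hypothetical merge step might look like the following fragment, where bucketsByTerm and incoming are illustrative names, not part of this commit:

```java
// Hypothetical: fold an incoming partial bucket into an accumulator keyed by
// term. Relies on the now-public termBytes field and the new setDocCount().
StringTerms.Bucket existing = bucketsByTerm.get(incoming.termBytes);
if (existing == null) {
    bucketsByTerm.put(incoming.termBytes, incoming);
} else {
    existing.setDocCount(existing.getDocCount() + incoming.getDocCount());
}
```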
The remaining five changed files are not shown.
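As a closing note, getTop500Buckets implements the standard bounded min-heap top-K pattern: a size-K min-heap ordered by doc count means each of the n buckets costs at most O(log K) to consider, for O(n log K) total instead of an O(n log n) full sort. A self-contained illustration with plain counts (TopKDemo is illustrative, not part of the commit):

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

public class TopKDemo {
    // Keep the k largest values using a size-k min-heap, then emit descending.
    static List<Long> topK(List<Long> counts, int k) {
        PriorityQueue<Long> heap = new PriorityQueue<>(k, Comparator.naturalOrder());
        for (long c : counts) {
            if (heap.size() < k) {
                heap.offer(c);
            } else if (c > heap.peek()) {
                heap.poll(); // evict the current minimum
                heap.offer(c);
            }
        }
        List<Long> result = new ArrayList<>(heap);
        result.sort(Comparator.reverseOrder()); // descending, like the real code
        return result;
    }

    public static void main(String[] args) {
        System.out.println(topK(List.of(5L, 1L, 9L, 3L, 7L, 2L), 3)); // [9, 7, 5]
    }
}
```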