diff --git a/avro-builder/builder-spi/build.gradle b/avro-builder/builder-spi/build.gradle
index d66b3306..f38867fc 100644
--- a/avro-builder/builder-spi/build.gradle
+++ b/avro-builder/builder-spi/build.gradle
@@ -17,6 +17,7 @@ dependencies {
   implementation "org.apache.logging.log4j:log4j-api:2.17.1"
   implementation "commons-io:commons-io:2.11.0"
   implementation "jakarta.json:jakarta.json-api:2.0.1"
+  implementation "com.pivovarit:parallel-collectors:2.5.0"
   testImplementation "org.apache.avro:avro:1.9.2"
 }
diff --git a/avro-builder/builder-spi/src/main/java/com/linkedin/avroutil1/builder/plugins/BuilderPluginContext.java b/avro-builder/builder-spi/src/main/java/com/linkedin/avroutil1/builder/plugins/BuilderPluginContext.java
index df9dcf93..8c96be21 100644
--- a/avro-builder/builder-spi/src/main/java/com/linkedin/avroutil1/builder/plugins/BuilderPluginContext.java
+++ b/avro-builder/builder-spi/src/main/java/com/linkedin/avroutil1/builder/plugins/BuilderPluginContext.java
@@ -8,8 +8,11 @@
 import com.linkedin.avroutil1.builder.operations.Operation;
 import com.linkedin.avroutil1.builder.operations.OperationContext;
+import com.linkedin.avroutil1.builder.util.StreamUtil;
 import java.util.ArrayList;
 import java.util.List;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 
 /**
@@ -17,6 +20,8 @@
  */
 public class BuilderPluginContext {
 
+  private static final Logger LOGGER = LoggerFactory.getLogger(BuilderPluginContext.class);
+
   private final List<Operation> operations = new ArrayList<>(1);
   private volatile boolean sealed = false;
   private final OperationContext operationContext;
@@ -43,12 +48,16 @@ public void run() throws Exception {
     //"seal" any internal state to prevent plugins from trying to do weird things during execution
     sealed = true;
 
-    operations.parallelStream().forEach(op -> {
+    int operationCount = operations.stream().collect(StreamUtil.toParallelStream(op -> {
       try {
         op.run(operationContext);
       } catch (Exception e) {
         throw new IllegalStateException("Exception running operation", e);
       }
-    });
+
+      return 1;
+    }, 2)).reduce(0, Integer::sum);
+
+    LOGGER.info("Executed {} operations for builder plugins", operationCount);
   }
 }
diff --git a/avro-builder/builder-spi/src/main/java/com/linkedin/avroutil1/builder/util/StreamUtil.java b/avro-builder/builder-spi/src/main/java/com/linkedin/avroutil1/builder/util/StreamUtil.java
new file mode 100644
index 00000000..5d2fa523
--- /dev/null
+++ b/avro-builder/builder-spi/src/main/java/com/linkedin/avroutil1/builder/util/StreamUtil.java
@@ -0,0 +1,53 @@
+/*
+ * Copyright 2024 LinkedIn Corp.
+ * Licensed under the BSD 2-Clause License (the "License").
+ * See License in the project root for license information.
+ */
+
+package com.linkedin.avroutil1.builder.util;
+
+import com.pivovarit.collectors.ParallelCollectors;
+import java.util.concurrent.Executor;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.SynchronousQueue;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Function;
+import java.util.stream.Collector;
+import java.util.stream.Stream;
+
+
+/**
+ * Utilities for dealing with java streams.
+ */
+public final class StreamUtil {
+
+  /**
+   * An (effectively) unbounded {@link ExecutorService} used for parallel processing. This is kept unbounded to avoid
+   * deadlocks caused when using {@link #toParallelStream(Function, int)} recursively. Callers are supposed to set
+   * sane values for parallelism to avoid spawning a crazy number of concurrent threads.
+   */
+  private static final ExecutorService WORK_EXECUTOR =
+      new ThreadPoolExecutor(0, Integer.MAX_VALUE, 60, TimeUnit.SECONDS, new SynchronousQueue<>());
+
+  private StreamUtil() {
+    // Disallow external instantiation.
+  }
+
+  /**
+   * A convenience {@link Collector} used for executing parallel computations on a custom {@link Executor}
+   * and returning a {@link Stream} instance returning results as they arrive.
+   * <p>

+   * For the parallelism of 1, the stream is executed by the calling thread.
+   *
+   * @param mapper a transformation to be performed in parallel
+   * @param parallelism the max parallelism level
+   * @param <T> the type of the collected elements
+   * @param <R> the result returned by {@code mapper}
+   *
+   * @return a {@code Collector} which collects all processed elements into a {@code Stream} in parallel.
+   */
+  public static <T, R> Collector<T, ?, Stream<R>> toParallelStream(Function<T, R> mapper, int parallelism) {
+    return ParallelCollectors.parallelToStream(mapper, WORK_EXECUTOR, parallelism);
+  }
+}
diff --git a/avro-builder/builder/src/main/java/com/linkedin/avroutil1/builder/operations/codegen/own/AvroUtilCodeGenPlugin.java b/avro-builder/builder/src/main/java/com/linkedin/avroutil1/builder/operations/codegen/own/AvroUtilCodeGenPlugin.java
index 2a9d3dd4..fad2cc65 100644
--- a/avro-builder/builder/src/main/java/com/linkedin/avroutil1/builder/operations/codegen/own/AvroUtilCodeGenPlugin.java
+++ b/avro-builder/builder/src/main/java/com/linkedin/avroutil1/builder/operations/codegen/own/AvroUtilCodeGenPlugin.java
@@ -9,6 +9,7 @@
 import com.linkedin.avroutil1.builder.operations.OperationContext;
 import com.linkedin.avroutil1.builder.operations.SchemaSet;
 import com.linkedin.avroutil1.builder.operations.codegen.CodeGenOpConfig;
+import com.linkedin.avroutil1.builder.util.StreamUtil;
 import com.linkedin.avroutil1.builder.plugins.BuilderPlugin;
 import com.linkedin.avroutil1.builder.plugins.BuilderPluginContext;
 import com.linkedin.avroutil1.codegen.SpecificRecordClassGenerator;
@@ -101,14 +102,14 @@ private void generateCode(OperationContext opContext) {
         AvroJavaStringRepresentation.fromJson(config.getMethodStringRepresentation().toString()),
         config.getMinAvroVersion(), config.isUtf8EncodingPutByIndexEnabled());
     final SpecificRecordClassGenerator generator = new SpecificRecordClassGenerator();
-    List<JavaFile> generatedClasses = allNamedSchemas.parallelStream().map(namedSchema -> {
+    List<JavaFile> generatedClasses = allNamedSchemas.stream().collect(StreamUtil.toParallelStream(namedSchema -> {
       try {
         // Top level schema
         return generator.generateSpecificClass(namedSchema, generationConfig);
       } catch (Exception e) {
         throw new RuntimeException("failed to generate class for " + namedSchema.getFullName(), e);
       }
-    }).collect(Collectors.toList());
+    }, 10)).collect(Collectors.toList());
 
     long genEnd = System.currentTimeMillis();
     LOGGER.info("Generated {} java source files in {} millis", generatedClasses.size(), genEnd - genStart);
@@ -129,7 +130,7 @@ private void writeJavaFilesToDisk(Collection<JavaFile> javaFiles, Path outputFol
     long writeStart = System.currentTimeMillis();
 
     // write out the files we generated
-    int filesWritten = javaFiles.parallelStream().map(javaFile -> {
+    int filesWritten = javaFiles.stream().collect(StreamUtil.toParallelStream(javaFile -> {
       try {
         javaFile.writeToPath(outputFolderPath);
       } catch (Exception e) {
@@ -137,7 +138,7 @@ private void writeJavaFilesToDisk(Collection<JavaFile> javaFiles, Path outputFol
       }
 
       return 1;
-    }).reduce(0, Integer::sum);
+    }, 10)).reduce(0, Integer::sum);
 
     long writeEnd = System.currentTimeMillis();
     LOGGER.info("Wrote out {} generated java source files under {} in {} millis", filesWritten, outputFolderPath,