diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java index ddc5ee25174..bf638dfcf77 100644 --- a/src/main/java/org/apache/sysds/api/DMLScript.java +++ b/src/main/java/org/apache/sysds/api/DMLScript.java @@ -326,6 +326,7 @@ public static boolean executeScript( Configuration conf, String[] args ) //reset runtime platform and visualize flag setGlobalExecMode(oldrtplatform); EXPLAIN = oldexplain; + CommonThreadPool.shutdownAsyncPools(); } return true; @@ -572,9 +573,6 @@ public static void cleanupHadoopExecution( DMLConfig config ) //0) cleanup federated workers if necessary FederatedData.clearFederatedWorkers(); - //0) shutdown prefetch/broadcast thread pool if necessary - CommonThreadPool.shutdownAsyncPools(); - //1) cleanup scratch space (everything for current uuid) //(required otherwise export to hdfs would skip assumed unnecessary writes if same name) HDFSTool.deleteFileIfExistOnHDFS( config.getTextValue(DMLConfig.SCRATCH_SPACE) + dirSuffix ); diff --git a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java index 62352bd2a01..088545b8ed8 100644 --- a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java +++ b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java @@ -29,7 +29,6 @@ import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.lops.Compression.CompressConfig; import org.apache.sysds.lops.compile.linearization.ILinearize; -import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysds.runtime.io.IOUtilFunctions; import org.apache.sysds.runtime.util.CommonThreadPool; @@ -66,7 +65,7 @@ public class ConfigurationManager{ _dmlconf = new DMLConfig(); _cconf = new CompilerConfig(); - final ExecutorService pool = CommonThreadPool.get(InfrastructureAnalyzer.getLocalParallelism()); + final ExecutorService pool = CommonThreadPool.get(); pool.submit(() ->{ try{ IOUtilFunctions.getFileSystem(_rJob); @@ -75,7 +74,7 @@ public class ConfigurationManager{ LOG.warn(e.getMessage()); } }); - pool.shutdown(); + // pool.shutdown(); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibStack.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibStack.java index 178c13ad297..ffea0193b71 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibStack.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibStack.java @@ -218,7 +218,7 @@ else if(finalCols[c] == null) { } } - final ExecutorService pool = CommonThreadPool.get(Math.max(Math.min(clen / 500, k), 1)); + final ExecutorService pool = CommonThreadPool.get(); try { List finalGroups = pool.submit(() -> { diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java index 94bbaf2545e..790a92de58f 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java @@ -19,11 +19,25 @@ package org.apache.sysds.runtime.controlprogram; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.log4j.Level; import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types.DataType; +import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.FileFormat; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.conf.CompilerConfig; @@ -31,7 +45,6 @@ import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.hops.recompile.Recompiler; import org.apache.sysds.lops.Lop; -import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.parser.DMLProgram; import org.apache.sysds.parser.DataIdentifier; import org.apache.sysds.parser.ParForStatementBlock; @@ -91,24 +104,10 @@ import org.apache.sysds.runtime.lineage.LineageItemUtils; import org.apache.sysds.runtime.meta.DataCharacteristics; import org.apache.sysds.runtime.util.CollectionUtils; -import org.apache.sysds.runtime.util.CommonThreadPool; import org.apache.sysds.runtime.util.ProgramConverter; import org.apache.sysds.runtime.util.UtilFunctions; import org.apache.sysds.utils.stats.ParForStatistics; -import java.io.Serializable; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Set; -import java.util.stream.Collectors; -import java.util.stream.IntStream; -import java.util.stream.Stream; - /** @@ -122,7 +121,7 @@ * */ public class ParForProgramBlock extends ForProgramBlock { - protected static final Log LOG = LogFactory.getLog(CommonThreadPool.class.getName()); + protected static final Log LOG = LogFactory.getLog(ParForProgramBlock.class.getName()); // execution modes public enum PExecMode { LOCAL, //local (master) multi-core execution mode diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java index 4db4f2b8b2d..985fdb056e5 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java @@ -614,7 +614,6 @@ public long getMaxIndexInRange(int dim) { */ public void forEachParallel(BiFunction forEachFunction) { ExecutorService pool = CommonThreadPool.get(_fedMap.size()); - ArrayList mappingTasks = new ArrayList<>(); for(Pair fedMap : _fedMap) mappingTasks.add(new MappingTask(fedMap.getKey(), fedMap.getValue(), forEachFunction, _ID)); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java index 5343332eb5f..a3be38cafd2 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java @@ -31,7 +31,6 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysds.runtime.instructions.cp.ListObject; import org.apache.sysds.runtime.util.CommonThreadPool; @@ -91,7 +90,7 @@ private void computeEpoch(long dataSize, int batchIter) { ListObject params = pullModel(); Future accGradients = ConcurrentUtils.constantFuture(null); if(_tpool == null) - _tpool = CommonThreadPool.get(InfrastructureAnalyzer.getLocalParallelism()); + _tpool = CommonThreadPool.get(); try { for (int j = 0; j < batchIter; j++) { diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java index ed5d48d6b38..94ab8f00de2 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java @@ -49,7 +49,6 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.codegen.CodegenUtils; import org.apache.sysds.runtime.controlprogram.caching.CacheBlock; -import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysds.runtime.frame.data.columns.Array; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory; @@ -860,26 +859,25 @@ private double arraysSizeInMemory() { size += ArrayFactory.getInMemorySize(_schema[j], rlen, true); else {// allocated if(rlen > 1000 && clen > 10 && ConfigurationManager.isParallelIOEnabled()) { - final ExecutorService pool = CommonThreadPool.get(InfrastructureAnalyzer.getLocalParallelism()); + final ExecutorService pool = CommonThreadPool.get(); try { size += pool.submit(() -> { return Arrays.stream(_coldata).parallel() // parallel columns .map(x -> x.getInMemorySize()).reduce(0L, Long::sum); }).get(); - pool.shutdown(); - } catch(InterruptedException | ExecutionException e) { - pool.shutdown(); LOG.error(e); for(Array aa : _coldata) size += aa.getInMemorySize(); } + finally{ + pool.shutdown(); + } } else { for(Array aa : _coldata) size += aa.getInMemorySize(); - } } return size; diff --git a/src/main/java/org/apache/sysds/runtime/io/FrameReaderJSONLParallel.java b/src/main/java/org/apache/sysds/runtime/io/FrameReaderJSONLParallel.java index 17abd9e3c8a..14143e00999 100644 --- a/src/main/java/org/apache/sysds/runtime/io/FrameReaderJSONLParallel.java +++ b/src/main/java/org/apache/sysds/runtime/io/FrameReaderJSONLParallel.java @@ -53,7 +53,7 @@ protected void readJSONLFrameFromHDFS(Path path, JobConf jobConf, FileSystem fil splits = IOUtilFunctions.sortInputSplits(splits); try{ - ExecutorService executorPool = CommonThreadPool.get(Math.min(numThreads, splits.length)); + ExecutorService executorPool = CommonThreadPool.get(numThreads); //compute num rows per split ArrayList countRowsTasks = new ArrayList<>(); diff --git a/src/main/java/org/apache/sysds/runtime/iogen/FormatIdentifyer.java b/src/main/java/org/apache/sysds/runtime/iogen/FormatIdentifyer.java index aa02ad37fca..3cbc174d64b 100644 --- a/src/main/java/org/apache/sysds/runtime/iogen/FormatIdentifyer.java +++ b/src/main/java/org/apache/sysds/runtime/iogen/FormatIdentifyer.java @@ -647,19 +647,14 @@ private Pair, HashSet> buildValueKeyPattern() { colIndexes.add(0); try { - ExecutorService pool = CommonThreadPool.get(1); ArrayList tasks = new ArrayList<>(); tasks.add( new BuildColsKeyPatternSingleRowTask(prefixesRemovedReverse, prefixesRemoved, prefixes, suffixes, prefixesRemovedReverseSort, keys, colSuffixes, lcs, colIndexes)); - //wait until all tasks have been executed - List> rt = pool.invokeAll(tasks); - pool.shutdown(); - //check for exceptions - for(Future task : rt) - task.get(); + for(Callable task : tasks) + task.call(); } catch(Exception e) { throw new RuntimeException("Failed BuildValueKeyPattern.", e); @@ -770,19 +765,13 @@ private Pair, HashSet> buildIndexKeyPattern(boolean ke colIndexe.add(0); try { - ExecutorService pool = CommonThreadPool.get(1); ArrayList tasks = new ArrayList<>(); tasks.add( new BuildColsKeyPatternSingleRowTask(prefixesRemovedReverse, prefixesRemoved, prefixes, suffixes, prefixesRemovedReverseSort, keys, colSuffixes, lcs, colIndexe)); - - //wait until all tasks have been executed - List> rt = pool.invokeAll(tasks); - pool.shutdown(); - //check for exceptions - for(Future task : rt) - task.get(); + for(Callable task : tasks) + task.call(); } catch(Exception e) { throw new RuntimeException("Failed BuildValueKeyPattern.", e); diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNN.java index 598fef549dd..26a00425a07 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNN.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNN.java @@ -670,7 +670,7 @@ private static long execute(ArrayList> tasks, DnnParameters param } } else { - ExecutorService pool = CommonThreadPool.get( Math.min(k, params.N) ); + ExecutorService pool = CommonThreadPool.get(k); List> taskret = pool.invokeAll(tasks); pool.shutdown(); for( Future task : taskret ) diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java index 5aaf0cd46a5..01a5216b4bb 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java @@ -35,8 +35,8 @@ import java.util.concurrent.Future; import java.util.stream.IntStream; -import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.lang3.ArrayUtils; +import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.lang3.concurrent.ConcurrentUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -54,7 +54,6 @@ import org.apache.sysds.runtime.compress.lib.CLALibAggTernaryOp; import org.apache.sysds.runtime.controlprogram.caching.CacheBlock; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType; -import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.DenseBlockFP64; import org.apache.sysds.runtime.data.DenseBlockFactory; @@ -373,7 +372,7 @@ public final MatrixBlock allocateDenseBlock() { } public Future allocateBlockAsync() { - ExecutorService pool = CommonThreadPool.get(InfrastructureAnalyzer.getLocalParallelism()); + ExecutorService pool = CommonThreadPool.get(); return (pool != null) ? pool.submit(() -> allocateBlock()) : //async ConcurrentUtils.constantFuture(allocateBlock()); //fallback sync } diff --git a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java index cc6483d2588..bc3be9844c9 100644 --- a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java +++ b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java @@ -99,14 +99,14 @@ public static ExecutorService get() { * @param k The number of threads wanted * @return The executor with specified parallelism */ - public static ExecutorService get(int k) { + public synchronized static ExecutorService get(int k) { if(size == k) return shared; else if(Thread.currentThread().getName().equals("main")) { if(shared2 != null && shared2K == k) return shared2; else if(shared2 == null) { - shared2 = new CommonThreadPool(Executors.newFixedThreadPool(k)); + shared2 = new CommonThreadPool(new ForkJoinPool(k)); shared2K = k; return shared2; } @@ -141,12 +141,13 @@ public static void invokeAndShutdown(ExecutorService pool, Collection r : ret) r.get(); - // shutdown pool - pool.shutdown(); } catch(Exception ex) { throw new DMLRuntimeException(ex); } + finally{ + pool.shutdown(); + } } /** @@ -155,8 +156,8 @@ public static void invokeAndShutdown(ExecutorService pool, Collection[] _locks; protected byte[][] _buff; private int _pos; diff --git a/src/test/java/org/apache/sysds/performance/Main.java b/src/test/java/org/apache/sysds/performance/Main.java index 1e51a703bf7..fa89a62b536 100644 --- a/src/test/java/org/apache/sysds/performance/Main.java +++ b/src/test/java/org/apache/sysds/performance/Main.java @@ -123,10 +123,11 @@ private static void run11(String[] args, int id) throws InterruptedException, Ex public static void main(String[] args) { try { exec(Integer.parseInt(args[0]), args); - CommonThreadPool.get().shutdown(); } catch(Exception e) { e.printStackTrace(); + }finally{ + CommonThreadPool.get().shutdown(); } } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java index aa7141bd18c..0b4e193ac4d 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java @@ -19,15 +19,20 @@ package org.apache.sysds.test.functions.federated.multitenant; +import static org.junit.Assert.fail; + import java.io.IOException; import java.nio.charset.Charset; import java.util.ArrayList; import java.util.Arrays; - -import static org.junit.Assert.fail; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import org.apache.commons.io.IOUtils; import org.apache.commons.lang3.ArrayUtils; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types.ExecMode; import org.apache.sysds.test.AutomatedTestBase; @@ -36,6 +41,8 @@ import com.google.crypto.tink.subtle.Random; public abstract class MultiTenantTestBase extends AutomatedTestBase { + protected static final Log LOG = LogFactory.getLog(MultiTenantTestBase.class.getName()); + protected ArrayList workerProcesses = new ArrayList<>(); protected ArrayList coordinatorProcesses = new ArrayList<>(); @@ -56,8 +63,7 @@ protected int[] startFedWorkers(int numFedWorkers) { } /** - * Start numFedWorkers federated worker processes on available ports and add - * them to the workerProcesses + * Start numFedWorkers federated worker processes on available ports and add them to the workerProcesses * * @param numFedWorkers the number of federated workers to start * @return int[] the ports of the created federated workers @@ -67,20 +73,20 @@ protected int[] startFedWorkers(int numFedWorkers, String[] addArgs) { for(int counter = 0; counter < numFedWorkers; counter++) { ports[counter] = getRandomAvailablePort(); // start process but only wait long for last one. - Process tmpProcess = startLocalFedWorker(ports[counter], addArgs, - counter == numFedWorkers-1 ? (FED_WORKER_WAIT + Random.randInt(1000)) * 3 : FED_WORKER_WAIT_S); + Process tmpProcess = startLocalFedWorker(ports[counter], addArgs, + counter == numFedWorkers - 1 ? (FED_WORKER_WAIT + Random.randInt(1000)) * 3 : FED_WORKER_WAIT_S); workerProcesses.add(tmpProcess); } return ports; } /** - * Start a coordinator process running the specified script with given arguments - * and add it to the coordinatorProcesses + * Start a coordinator process running the specified script with given arguments and add it to the + * coordinatorProcesses * - * @param execMode the execution mode of the coordinator + * @param execMode the execution mode of the coordinator * @param scriptPath the path to the dml script - * @param args the program arguments for running the dml script + * @param args the program arguments for running the dml script */ protected void startCoordinator(ExecMode execMode, String scriptPath, String[] args) { String separator = System.getProperty("file.separator"); @@ -90,14 +96,14 @@ protected void startCoordinator(ExecMode execMode, String scriptPath, String[] a String em = null; switch(execMode) { case SINGLE_NODE: - em = "singlenode"; - break; + em = "singlenode"; + break; case HYBRID: - em = "hybrid"; - break; + em = "hybrid"; + break; case SPARK: - em = "spark"; - break; + em = "spark"; + break; } ArrayList argsList = new ArrayList<>(); @@ -108,13 +114,14 @@ protected void startCoordinator(ExecMode execMode, String scriptPath, String[] a argsList.addAll(Arrays.asList(args)); // create the processBuilder and redirect the stderr to its stdout - ProcessBuilder processBuilder = new ProcessBuilder(ArrayUtils.addAll(new String[]{ - path, "-cp", classpath, DMLScript.class.getName()}, argsList.toArray(new String[0]))); + ProcessBuilder processBuilder = new ProcessBuilder(ArrayUtils + .addAll(new String[] {path, "-cp", classpath, DMLScript.class.getName()}, argsList.toArray(new String[0]))); Process process = null; try { process = processBuilder.start(); - } catch(IOException ioe) { + } + catch(IOException ioe) { ioe.printStackTrace(); fail("Can't start the coordinator process."); } @@ -122,12 +129,28 @@ protected void startCoordinator(ExecMode execMode, String scriptPath, String[] a } /** - * Wait for all processes of coordinatorProcesses to terminate and collect - * their output + * Wait for all processes of coordinatorProcesses to terminate and collect their output * * @return String the collected output of the coordinator processes */ protected String waitForCoordinators() { + return waitForCoordinators(500); + } + + protected String waitForCoordinators(int timeout){ + ExecutorService executor = Executors.newCachedThreadPool(); + try{ + return executor.submit(() -> waitForCoordinatorsActual()).get(timeout, TimeUnit.SECONDS); + } + catch(Exception e){ + throw new RuntimeException(e); + } + finally{ + executor.shutdown(); + } + } + + private String waitForCoordinatorsActual(){ // wait for the coordinator processes to finish and collect their output StringBuilder outputLog = new StringBuilder(); for(int counter = 0; counter < coordinatorProcesses.size(); counter++) { @@ -139,9 +162,10 @@ protected String waitForCoordinators() { outputLog.append(IOUtils.toString(coord.getErrorStream(), Charset.defaultCharset())); coord.waitFor(); - } catch(Exception ex) { + } + catch(Exception ex) { fail(ex.getClass().getSimpleName() + " thrown while collecting log output of coordinator #" - + Integer.toString(counter+1) + ".\n"); + + Integer.toString(counter + 1) + ".\n"); ex.printStackTrace(); } } diff --git a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSyntaxTest.java b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSyntaxTest.java index f3804066e3c..7d4b5f4b257 100644 --- a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSyntaxTest.java +++ b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSyntaxTest.java @@ -93,9 +93,11 @@ public void testParamservMinimumVersion() { private void runDMLTest(String testname, boolean exceptionExpected, Class exceptionClass, String errmsg) { TestConfiguration config = getTestConfiguration(testname); + setOutputBuffering(true); loadTestConfiguration(config); programArgs = new String[] { "-explain" }; fullDMLScriptName = HOME + testname + ".dml"; runTest(true, exceptionExpected, exceptionClass, errmsg, -1); + setOutputBuffering(false); } }