Skip to content

Commit

Permalink
[MINOR] Fix ThreadPool for Federated
Browse files Browse the repository at this point in the history
This commit goes through the instances that call CommonThreadPool, and
fixes the remaining issues. The new double buffering is unfortunately not
one of them so i changed it to use a static single thread extra.

Closes #1877
  • Loading branch information
Baunsgaard committed Aug 9, 2023
1 parent 6dacde7 commit 649fc8e
Show file tree
Hide file tree
Showing 16 changed files with 94 additions and 86 deletions.
4 changes: 1 addition & 3 deletions src/main/java/org/apache/sysds/api/DMLScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ public static boolean executeScript( Configuration conf, String[] args )
//reset runtime platform and visualize flag
setGlobalExecMode(oldrtplatform);
EXPLAIN = oldexplain;
CommonThreadPool.shutdownAsyncPools();
}

return true;
Expand Down Expand Up @@ -572,9 +573,6 @@ public static void cleanupHadoopExecution( DMLConfig config )
//0) cleanup federated workers if necessary
FederatedData.clearFederatedWorkers();

//0) shutdown prefetch/broadcast thread pool if necessary
CommonThreadPool.shutdownAsyncPools();

//1) cleanup scratch space (everything for current uuid)
//(required otherwise export to hdfs would skip assumed unnecessary writes if same name)
HDFSTool.deleteFileIfExistOnHDFS( config.getTextValue(DMLConfig.SCRATCH_SPACE) + dirSuffix );
Expand Down
5 changes: 2 additions & 3 deletions src/main/java/org/apache/sysds/conf/ConfigurationManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.Compression.CompressConfig;
import org.apache.sysds.lops.compile.linearization.ILinearize;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.util.CommonThreadPool;

Expand Down Expand Up @@ -66,7 +65,7 @@ public class ConfigurationManager{
_dmlconf = new DMLConfig();
_cconf = new CompilerConfig();

final ExecutorService pool = CommonThreadPool.get(InfrastructureAnalyzer.getLocalParallelism());
final ExecutorService pool = CommonThreadPool.get();
pool.submit(() ->{
try{
IOUtilFunctions.getFileSystem(_rJob);
Expand All @@ -75,7 +74,7 @@ public class ConfigurationManager{
LOG.warn(e.getMessage());
}
});
pool.shutdown();
// pool.shutdown();
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ else if(finalCols[c] == null) {
}
}

final ExecutorService pool = CommonThreadPool.get(Math.max(Math.min(clen / 500, k), 1));
final ExecutorService pool = CommonThreadPool.get();
try {

List<AColGroup> finalGroups = pool.submit(() -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,32 @@

package org.apache.sysds.runtime.controlprogram;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.log4j.Level;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.CompilerConfig;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.ParForStatementBlock;
Expand Down Expand Up @@ -91,24 +104,10 @@
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.CollectionUtils;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.ProgramConverter;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.stats.ParForStatistics;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;



/**
Expand All @@ -122,7 +121,7 @@
*
*/
public class ParForProgramBlock extends ForProgramBlock {
protected static final Log LOG = LogFactory.getLog(CommonThreadPool.class.getName());
protected static final Log LOG = LogFactory.getLog(ParForProgramBlock.class.getName());
// execution modes
public enum PExecMode {
LOCAL, //local (master) multi-core execution mode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,6 @@ public long getMaxIndexInRange(int dim) {
*/
public void forEachParallel(BiFunction<FederatedRange, FederatedData, Void> forEachFunction) {
ExecutorService pool = CommonThreadPool.get(_fedMap.size());

ArrayList<MappingTask> mappingTasks = new ArrayList<>();
for(Pair<FederatedRange, FederatedData> fedMap : _fedMap)
mappingTasks.add(new MappingTask(fedMap.getKey(), fedMap.getValue(), forEachFunction, _ID));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.util.CommonThreadPool;
Expand Down Expand Up @@ -91,7 +90,7 @@ private void computeEpoch(long dataSize, int batchIter) {
ListObject params = pullModel();
Future<ListObject> accGradients = ConcurrentUtils.constantFuture(null);
if(_tpool == null)
_tpool = CommonThreadPool.get(InfrastructureAnalyzer.getLocalParallelism());
_tpool = CommonThreadPool.get();

try {
for (int j = 0; j < batchIter; j++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.codegen.CodegenUtils;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.frame.data.columns.Array;
import org.apache.sysds.runtime.frame.data.columns.ArrayFactory;
Expand Down Expand Up @@ -860,26 +859,25 @@ private double arraysSizeInMemory() {
size += ArrayFactory.getInMemorySize(_schema[j], rlen, true);
else {// allocated
if(rlen > 1000 && clen > 10 && ConfigurationManager.isParallelIOEnabled()) {
final ExecutorService pool = CommonThreadPool.get(InfrastructureAnalyzer.getLocalParallelism());
final ExecutorService pool = CommonThreadPool.get();
try {
size += pool.submit(() -> {
return Arrays.stream(_coldata).parallel() // parallel columns
.map(x -> x.getInMemorySize()).reduce(0L, Long::sum);
}).get();
pool.shutdown();

}
catch(InterruptedException | ExecutionException e) {
pool.shutdown();
LOG.error(e);
for(Array<?> aa : _coldata)
size += aa.getInMemorySize();
}
finally{
pool.shutdown();
}
}
else {
for(Array<?> aa : _coldata)
size += aa.getInMemorySize();

}
}
return size;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ protected void readJSONLFrameFromHDFS(Path path, JobConf jobConf, FileSystem fil
splits = IOUtilFunctions.sortInputSplits(splits);

try{
ExecutorService executorPool = CommonThreadPool.get(Math.min(numThreads, splits.length));
ExecutorService executorPool = CommonThreadPool.get(numThreads);

//compute num rows per split
ArrayList<CountRowsTask> countRowsTasks = new ArrayList<>();
Expand Down
19 changes: 4 additions & 15 deletions src/main/java/org/apache/sysds/runtime/iogen/FormatIdentifyer.java
Original file line number Diff line number Diff line change
Expand Up @@ -647,19 +647,14 @@ private Pair<ArrayList<String>, HashSet<String>> buildValueKeyPattern() {
colIndexes.add(0);

try {
ExecutorService pool = CommonThreadPool.get(1);
ArrayList<BuildColsKeyPatternSingleRowTask> tasks = new ArrayList<>();
tasks.add(
new BuildColsKeyPatternSingleRowTask(prefixesRemovedReverse, prefixesRemoved, prefixes, suffixes,
prefixesRemovedReverseSort, keys, colSuffixes, lcs, colIndexes));

//wait until all tasks have been executed
List<Future<Object>> rt = pool.invokeAll(tasks);
pool.shutdown();

//check for exceptions
for(Future<Object> task : rt)
task.get();
for(Callable<Object> task : tasks)
task.call();
}
catch(Exception e) {
throw new RuntimeException("Failed BuildValueKeyPattern.", e);
Expand Down Expand Up @@ -770,19 +765,13 @@ private Pair<ArrayList<String>, HashSet<String>> buildIndexKeyPattern(boolean ke
colIndexe.add(0);

try {
ExecutorService pool = CommonThreadPool.get(1);
ArrayList<BuildColsKeyPatternSingleRowTask> tasks = new ArrayList<>();
tasks.add(
new BuildColsKeyPatternSingleRowTask(prefixesRemovedReverse, prefixesRemoved, prefixes, suffixes,
prefixesRemovedReverseSort, keys, colSuffixes, lcs, colIndexe));

//wait until all tasks have been executed
List<Future<Object>> rt = pool.invokeAll(tasks);
pool.shutdown();

//check for exceptions
for(Future<Object> task : rt)
task.get();
for(Callable<Object> task : tasks)
task.call();
}
catch(Exception e) {
throw new RuntimeException("Failed BuildValueKeyPattern.", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,7 @@ private static long execute(ArrayList<Callable<Long>> tasks, DnnParameters param
}
}
else {
ExecutorService pool = CommonThreadPool.get( Math.min(k, params.N) );
ExecutorService pool = CommonThreadPool.get(k);
List<Future<Long>> taskret = pool.invokeAll(tasks);
pool.shutdown();
for( Future<Long> task : taskret )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
import java.util.concurrent.Future;
import java.util.stream.IntStream;

import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.lang3.concurrent.ConcurrentUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand All @@ -54,7 +54,6 @@
import org.apache.sysds.runtime.compress.lib.CLALibAggTernaryOp;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.DenseBlockFP64;
import org.apache.sysds.runtime.data.DenseBlockFactory;
Expand Down Expand Up @@ -373,7 +372,7 @@ public final MatrixBlock allocateDenseBlock() {
}

public Future<MatrixBlock> allocateBlockAsync() {
ExecutorService pool = CommonThreadPool.get(InfrastructureAnalyzer.getLocalParallelism());
ExecutorService pool = CommonThreadPool.get();
return (pool != null) ? pool.submit(() -> allocateBlock()) : //async
ConcurrentUtils.constantFuture(allocateBlock()); //fallback sync
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,14 @@ public static ExecutorService get() {
* @param k The number of threads wanted
* @return The executor with specified parallelism
*/
public static ExecutorService get(int k) {
public synchronized static ExecutorService get(int k) {
if(size == k)
return shared;
else if(Thread.currentThread().getName().equals("main")) {
if(shared2 != null && shared2K == k)
return shared2;
else if(shared2 == null) {
shared2 = new CommonThreadPool(Executors.newFixedThreadPool(k));
shared2 = new CommonThreadPool(new ForkJoinPool(k));
shared2K = k;
return shared2;
}
Expand Down Expand Up @@ -141,12 +141,13 @@ public static <T> void invokeAndShutdown(ExecutorService pool, Collection<? exte
// check for errors and exceptions
for(Future<T> r : ret)
r.get();
// shutdown pool
pool.shutdown();
}
catch(Exception ex) {
throw new DMLRuntimeException(ex);
}
finally{
pool.shutdown();
}
}

/**
Expand All @@ -155,8 +156,8 @@ public static <T> void invokeAndShutdown(ExecutorService pool, Collection<? exte
*
* @return A dynamic thread pool.
*/
public static ExecutorService getDynamicPool() {
if(asyncPool != null)
public synchronized static ExecutorService getDynamicPool() {
if(asyncPool != null && !(asyncPool.isShutdown() || asyncPool.isTerminated()) )
return asyncPool;
else {
asyncPool = Executors.newCachedThreadPool();
Expand All @@ -167,7 +168,7 @@ public static ExecutorService getDynamicPool() {
/**
* Shutdown the cached thread pools.
*/
public static void shutdownAsyncPools() {
public synchronized static void shutdownAsyncPools() {
if(asyncPool != null) {
// shutdown prefetch/broadcast thread pool
asyncPool.shutdown();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
import java.io.OutputStream;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import org.apache.commons.lang3.concurrent.ConcurrentUtils;

public class DoubleBufferingOutputStream extends FilterOutputStream
{
protected ExecutorService _pool = CommonThreadPool.get(1);
public class DoubleBufferingOutputStream extends FilterOutputStream {
protected ExecutorService _pool = Executors.newSingleThreadExecutor();
protected Future<?>[] _locks;
protected byte[][] _buff;
private int _pos;
Expand Down
3 changes: 2 additions & 1 deletion src/test/java/org/apache/sysds/performance/Main.java
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,11 @@ private static void run11(String[] args, int id) throws InterruptedException, Ex
public static void main(String[] args) {
try {
exec(Integer.parseInt(args[0]), args);
CommonThreadPool.get().shutdown();
}
catch(Exception e) {
e.printStackTrace();
}finally{
CommonThreadPool.get().shutdown();
}
}
}
Loading

0 comments on commit 649fc8e

Please sign in to comment.