Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MINOR] Fix ThreadPool for Federated #1877

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/main/java/org/apache/sysds/api/DMLScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ public static boolean executeScript( Configuration conf, String[] args )
//reset runtime platform and visualize flag
setGlobalExecMode(oldrtplatform);
EXPLAIN = oldexplain;
CommonThreadPool.shutdownAsyncPools();
}

return true;
Expand Down Expand Up @@ -572,9 +573,6 @@ public static void cleanupHadoopExecution( DMLConfig config )
//0) cleanup federated workers if necessary
FederatedData.clearFederatedWorkers();

//0) shutdown prefetch/broadcast thread pool if necessary
CommonThreadPool.shutdownAsyncPools();

//1) cleanup scratch space (everything for current uuid)
//(required otherwise export to hdfs would skip assumed unnecessary writes if same name)
HDFSTool.deleteFileIfExistOnHDFS( config.getTextValue(DMLConfig.SCRATCH_SPACE) + dirSuffix );
Expand Down
5 changes: 2 additions & 3 deletions src/main/java/org/apache/sysds/conf/ConfigurationManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.Compression.CompressConfig;
import org.apache.sysds.lops.compile.linearization.ILinearize;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.util.CommonThreadPool;

Expand Down Expand Up @@ -66,7 +65,7 @@ public class ConfigurationManager{
_dmlconf = new DMLConfig();
_cconf = new CompilerConfig();

final ExecutorService pool = CommonThreadPool.get(InfrastructureAnalyzer.getLocalParallelism());
final ExecutorService pool = CommonThreadPool.get();
pool.submit(() ->{
try{
IOUtilFunctions.getFileSystem(_rJob);
Expand All @@ -75,7 +74,7 @@ public class ConfigurationManager{
LOG.warn(e.getMessage());
}
});
pool.shutdown();
// pool.shutdown();
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ else if(finalCols[c] == null) {
}
}

final ExecutorService pool = CommonThreadPool.get(Math.max(Math.min(clen / 500, k), 1));
final ExecutorService pool = CommonThreadPool.get();
try {

List<AColGroup> finalGroups = pool.submit(() -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,32 @@

package org.apache.sysds.runtime.controlprogram;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.log4j.Level;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.CompilerConfig;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.ParForStatementBlock;
Expand Down Expand Up @@ -91,24 +104,10 @@
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.CollectionUtils;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.ProgramConverter;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.stats.ParForStatistics;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;



/**
Expand All @@ -122,7 +121,7 @@
*
*/
public class ParForProgramBlock extends ForProgramBlock {
protected static final Log LOG = LogFactory.getLog(CommonThreadPool.class.getName());
protected static final Log LOG = LogFactory.getLog(ParForProgramBlock.class.getName());
// execution modes
public enum PExecMode {
LOCAL, //local (master) multi-core execution mode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,6 @@ public long getMaxIndexInRange(int dim) {
*/
public void forEachParallel(BiFunction<FederatedRange, FederatedData, Void> forEachFunction) {
ExecutorService pool = CommonThreadPool.get(_fedMap.size());

ArrayList<MappingTask> mappingTasks = new ArrayList<>();
for(Pair<FederatedRange, FederatedData> fedMap : _fedMap)
mappingTasks.add(new MappingTask(fedMap.getKey(), fedMap.getValue(), forEachFunction, _ID));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.util.CommonThreadPool;
Expand Down Expand Up @@ -91,7 +90,7 @@ private void computeEpoch(long dataSize, int batchIter) {
ListObject params = pullModel();
Future<ListObject> accGradients = ConcurrentUtils.constantFuture(null);
if(_tpool == null)
_tpool = CommonThreadPool.get(InfrastructureAnalyzer.getLocalParallelism());
_tpool = CommonThreadPool.get();

try {
for (int j = 0; j < batchIter; j++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.codegen.CodegenUtils;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.frame.data.columns.Array;
import org.apache.sysds.runtime.frame.data.columns.ArrayFactory;
Expand Down Expand Up @@ -860,26 +859,25 @@ private double arraysSizeInMemory() {
size += ArrayFactory.getInMemorySize(_schema[j], rlen, true);
else {// allocated
if(rlen > 1000 && clen > 10 && ConfigurationManager.isParallelIOEnabled()) {
final ExecutorService pool = CommonThreadPool.get(InfrastructureAnalyzer.getLocalParallelism());
final ExecutorService pool = CommonThreadPool.get();
try {
size += pool.submit(() -> {
return Arrays.stream(_coldata).parallel() // parallel columns
.map(x -> x.getInMemorySize()).reduce(0L, Long::sum);
}).get();
pool.shutdown();

}
catch(InterruptedException | ExecutionException e) {
pool.shutdown();
LOG.error(e);
for(Array<?> aa : _coldata)
size += aa.getInMemorySize();
}
finally{
pool.shutdown();
}
}
else {
for(Array<?> aa : _coldata)
size += aa.getInMemorySize();

}
}
return size;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ protected void readJSONLFrameFromHDFS(Path path, JobConf jobConf, FileSystem fil
splits = IOUtilFunctions.sortInputSplits(splits);

try{
ExecutorService executorPool = CommonThreadPool.get(Math.min(numThreads, splits.length));
ExecutorService executorPool = CommonThreadPool.get(numThreads);

//compute num rows per split
ArrayList<CountRowsTask> countRowsTasks = new ArrayList<>();
Expand Down
19 changes: 4 additions & 15 deletions src/main/java/org/apache/sysds/runtime/iogen/FormatIdentifyer.java
Original file line number Diff line number Diff line change
Expand Up @@ -647,19 +647,14 @@ private Pair<ArrayList<String>, HashSet<String>> buildValueKeyPattern() {
colIndexes.add(0);

try {
ExecutorService pool = CommonThreadPool.get(1);
ArrayList<BuildColsKeyPatternSingleRowTask> tasks = new ArrayList<>();
tasks.add(
new BuildColsKeyPatternSingleRowTask(prefixesRemovedReverse, prefixesRemoved, prefixes, suffixes,
prefixesRemovedReverseSort, keys, colSuffixes, lcs, colIndexes));

//wait until all tasks have been executed
List<Future<Object>> rt = pool.invokeAll(tasks);
pool.shutdown();

//check for exceptions
for(Future<Object> task : rt)
task.get();
for(Callable<Object> task : tasks)
task.call();
}
catch(Exception e) {
throw new RuntimeException("Failed BuildValueKeyPattern.", e);
Expand Down Expand Up @@ -770,19 +765,13 @@ private Pair<ArrayList<String>, HashSet<String>> buildIndexKeyPattern(boolean ke
colIndexe.add(0);

try {
ExecutorService pool = CommonThreadPool.get(1);
ArrayList<BuildColsKeyPatternSingleRowTask> tasks = new ArrayList<>();
tasks.add(
new BuildColsKeyPatternSingleRowTask(prefixesRemovedReverse, prefixesRemoved, prefixes, suffixes,
prefixesRemovedReverseSort, keys, colSuffixes, lcs, colIndexe));

//wait until all tasks have been executed
List<Future<Object>> rt = pool.invokeAll(tasks);
pool.shutdown();

//check for exceptions
for(Future<Object> task : rt)
task.get();
for(Callable<Object> task : tasks)
task.call();
}
catch(Exception e) {
throw new RuntimeException("Failed BuildValueKeyPattern.", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,7 @@ private static long execute(ArrayList<Callable<Long>> tasks, DnnParameters param
}
}
else {
ExecutorService pool = CommonThreadPool.get( Math.min(k, params.N) );
ExecutorService pool = CommonThreadPool.get(k);
List<Future<Long>> taskret = pool.invokeAll(tasks);
pool.shutdown();
for( Future<Long> task : taskret )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
import java.util.concurrent.Future;
import java.util.stream.IntStream;

import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.lang3.concurrent.ConcurrentUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand All @@ -54,7 +54,6 @@
import org.apache.sysds.runtime.compress.lib.CLALibAggTernaryOp;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.DenseBlockFP64;
import org.apache.sysds.runtime.data.DenseBlockFactory;
Expand Down Expand Up @@ -373,7 +372,7 @@ public final MatrixBlock allocateDenseBlock() {
}

public Future<MatrixBlock> allocateBlockAsync() {
ExecutorService pool = CommonThreadPool.get(InfrastructureAnalyzer.getLocalParallelism());
ExecutorService pool = CommonThreadPool.get();
return (pool != null) ? pool.submit(() -> allocateBlock()) : //async
ConcurrentUtils.constantFuture(allocateBlock()); //fallback sync
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,14 @@ public static ExecutorService get() {
* @param k The number of threads wanted
* @return The executor with specified parallelism
*/
public static ExecutorService get(int k) {
public synchronized static ExecutorService get(int k) {
if(size == k)
return shared;
else if(Thread.currentThread().getName().equals("main")) {
if(shared2 != null && shared2K == k)
return shared2;
else if(shared2 == null) {
shared2 = new CommonThreadPool(Executors.newFixedThreadPool(k));
shared2 = new CommonThreadPool(new ForkJoinPool(k));
shared2K = k;
return shared2;
}
Expand Down Expand Up @@ -141,12 +141,13 @@ public static <T> void invokeAndShutdown(ExecutorService pool, Collection<? exte
// check for errors and exceptions
for(Future<T> r : ret)
r.get();
// shutdown pool
pool.shutdown();
}
catch(Exception ex) {
throw new DMLRuntimeException(ex);
}
finally{
pool.shutdown();
}
}

/**
Expand All @@ -155,8 +156,8 @@ public static <T> void invokeAndShutdown(ExecutorService pool, Collection<? exte
*
* @return A dynamic thread pool.
*/
public static ExecutorService getDynamicPool() {
if(asyncPool != null)
public synchronized static ExecutorService getDynamicPool() {
if(asyncPool != null && !(asyncPool.isShutdown() || asyncPool.isTerminated()) )
return asyncPool;
else {
asyncPool = Executors.newCachedThreadPool();
Expand All @@ -167,7 +168,7 @@ public static ExecutorService getDynamicPool() {
/**
* Shutdown the cached thread pools.
*/
public static void shutdownAsyncPools() {
public synchronized static void shutdownAsyncPools() {
if(asyncPool != null) {
// shutdown prefetch/broadcast thread pool
asyncPool.shutdown();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
import java.io.OutputStream;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import org.apache.commons.lang3.concurrent.ConcurrentUtils;

public class DoubleBufferingOutputStream extends FilterOutputStream
{
protected ExecutorService _pool = CommonThreadPool.get(1);
public class DoubleBufferingOutputStream extends FilterOutputStream {
protected ExecutorService _pool = Executors.newSingleThreadExecutor();
protected Future<?>[] _locks;
protected byte[][] _buff;
private int _pos;
Expand Down
3 changes: 2 additions & 1 deletion src/test/java/org/apache/sysds/performance/Main.java
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,11 @@ private static void run11(String[] args, int id) throws InterruptedException, Ex
public static void main(String[] args) {
try {
exec(Integer.parseInt(args[0]), args);
CommonThreadPool.get().shutdown();
}
catch(Exception e) {
e.printStackTrace();
}finally{
CommonThreadPool.get().shutdown();
}
}
}
Loading
Loading