diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java index ff70b330aea..ddc5ee25174 100644 --- a/src/main/java/org/apache/sysds/api/DMLScript.java +++ b/src/main/java/org/apache/sysds/api/DMLScript.java @@ -573,7 +573,7 @@ public static void cleanupHadoopExecution( DMLConfig config ) FederatedData.clearFederatedWorkers(); //0) shutdown prefetch/broadcast thread pool if necessary - CommonThreadPool.shutdownAsyncRDDPool(); + CommonThreadPool.shutdownAsyncPools(); //1) cleanup scratch space (everything for current uuid) //(required otherwise export to hdfs would skip assumed unnecessary writes if same name) diff --git a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java index 0ff071ebddd..cc6483d2588 100644 --- a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java +++ b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java @@ -30,7 +30,6 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.DMLRuntimeException; @@ -62,21 +61,23 @@ public class CommonThreadPool implements ExecutorService { */ private static final ExecutorService shared = ForkJoinPool.commonPool(); /** A secondary thread local executor that use a custom number of threads */ - private static ExecutorService shared2 = null; + private static CommonThreadPool shared2 = null; /** The number of threads used in the custom secondary executor */ private static int shared2K = -1; /** Dynamic thread pool, that dynamically allocate threads as tasks come in. */ - private static ExecutorService triggerRemoteOPsPool = null; + private static ExecutorService asyncPool = null; /** This common thread pool */ private final ExecutorService _pool; /** - * Private constructor of the threadPool. + * Constructor of the threadPool. + * This is intended not to be used except for tests. + * Please use the static constructors. * * @param pool The thread pool instance to use. */ - private CommonThreadPool(ExecutorService pool) { - _pool = pool; + public CommonThreadPool(ExecutorService pool) { + this._pool = pool; } /** @@ -109,12 +110,11 @@ else if(shared2 == null) { shared2K = k; return shared2; } - else { - return Executors.newFixedThreadPool(k); - } + else + return new CommonThreadPool(Executors.newFixedThreadPool(k)); } else - return Executors.newFixedThreadPool(k); + return new CommonThreadPool(Executors.newFixedThreadPool(k)); } /** @@ -124,7 +124,7 @@ else if(shared2 == null) { * @return If we have a cached thread pool. */ public static boolean isSharedTPThreads(int k) { - return InfrastructureAnalyzer.getLocalParallelism() == k || shared2K == k || shared2K == -1; + return size == k || shared2K == k || shared2K == -1; } /** @@ -156,27 +156,33 @@ public static void invokeAndShutdown(ExecutorService pool, Collection shutdownNow() { - return !isCached() ? null : _pool.shutdownNow(); + return !isCached() ? _pool.shutdownNow() : null; } @Override @@ -221,30 +227,29 @@ public Future submit(Runnable task) { return _pool.submit(task); } - // unnecessary methods required for API compliance @Override public boolean isShutdown() { - throw new NotImplementedException(); + return isCached() || _pool.isShutdown(); } @Override public boolean isTerminated() { - throw new NotImplementedException(); + return isCached() || _pool.isTerminated(); } @Override public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { - throw new NotImplementedException(); + return isCached() || _pool.awaitTermination(timeout, unit); } @Override public T invokeAny(Collection> tasks) throws InterruptedException, ExecutionException { - throw new NotImplementedException(); + return _pool.invokeAny(tasks); } @Override public T invokeAny(Collection> tasks, long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { - throw new NotImplementedException(); + return _pool.invokeAny(tasks); } } diff --git a/src/test/java/org/apache/sysds/test/component/misc/ThreadPool.java b/src/test/java/org/apache/sysds/test/component/misc/ThreadPool.java index 9bfdee72a9a..ca79e8800b3 100644 --- a/src/test/java/org/apache/sysds/test/component/misc/ThreadPool.java +++ b/src/test/java/org/apache/sysds/test/component/misc/ThreadPool.java @@ -20,59 +20,389 @@ package org.apache.sysds.test.component.misc; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysds.runtime.util.CommonThreadPool; import org.junit.Test; public class ThreadPool { + protected static final Log LOG = LogFactory.getLog(ThreadPool.class.getName()); + @Test public void testGetTheSame() { + CommonThreadPool.shutdownAsyncPools(); ExecutorService x = CommonThreadPool.get(); ExecutorService y = CommonThreadPool.get(); x.shutdown(); y.shutdown(); assertEquals(x, y); + CommonThreadPool.shutdownAsyncPools(); + CommonThreadPool.shutdownAsyncPools(); } @Test public void testGetSameCustomThreadCount() { + CommonThreadPool.shutdownAsyncPools(); + // choosing 7 because the machine is unlikely to have 7 logical cores. + String name = Thread.currentThread().getName(); + Thread.currentThread().setName("main"); + ExecutorService x = CommonThreadPool.get(7); + ExecutorService y = CommonThreadPool.get(7); + x.shutdown(); + y.shutdown(); + + Thread.currentThread().setName(name); + assertEquals(x, y); + CommonThreadPool.shutdownAsyncPools(); + CommonThreadPool.shutdownAsyncPools(); + + } + + @Test + public void testGetSameCustomThreadCountExecute() throws InterruptedException, ExecutionException { + // choosing 7 because the machine is unlikely to have 7 logical cores. + CommonThreadPool.shutdownAsyncPools(); + String name = Thread.currentThread().getName(); + Thread.currentThread().setName("main"); + ExecutorService x = CommonThreadPool.get(7); + ExecutorService y = CommonThreadPool.get(7); + assertEquals(x, y); + int v = x.submit(() -> 5).get(); + x.shutdown(); + int v2 = y.submit(() -> 5).get(); + y.shutdown(); + + Thread.currentThread().setName(name); + assertEquals(x, y); + assertEquals(v, v2); + CommonThreadPool.shutdownAsyncPools(); + } + + @Test + public void testGetSameCustomThreadCountExecuteV2() throws InterruptedException, ExecutionException { + // choosing 7 because the machine is unlikely to have 7 logical cores. + String name = Thread.currentThread().getName(); + Thread.currentThread().setName("main"); + ExecutorService x = CommonThreadPool.get(7); + ExecutorService y = CommonThreadPool.get(7); + assertEquals(x, y); + int v = x.submit(() -> 5).get(); + int v2 = y.submit(() -> 5).get(); + x.shutdown(); + y.shutdown(); + + Thread.currentThread().setName(name); + assertEquals(x, y); + assertEquals(v, v2); + CommonThreadPool.shutdownAsyncPools(); + } + + @Test + public void testGetSameCustomThreadCountExecuteV3() throws InterruptedException, ExecutionException { // choosing 7 because the machine is unlikely to have 7 logical cores. String name = Thread.currentThread().getName(); Thread.currentThread().setName("main"); ExecutorService x = CommonThreadPool.get(7); ExecutorService y = CommonThreadPool.get(7); + assertEquals(x, y); x.shutdown(); y.shutdown(); + int v = x.submit(() -> 5).get(); + int v2 = y.submit(() -> 5).get(); Thread.currentThread().setName(name); assertEquals(x, y); + assertEquals(v, v2); + CommonThreadPool.shutdownAsyncPools(); + } + @Test + public void testGetSameCustomThreadCountExecuteV4() throws InterruptedException, ExecutionException { + // choosing 7 because the machine is unlikely to have 7 logical cores. + String name = Thread.currentThread().getName(); + Thread.currentThread().setName("main"); + CommonThreadPool.shutdownAsyncPools(); + ExecutorService x = CommonThreadPool.get(5); + ExecutorService y = CommonThreadPool.get(7); + assertNotEquals(x, y); + x.shutdown(); + int v = x.submit(() -> 5).get(); + int v2 = y.submit(() -> 5).get(); + y.shutdown(); + + Thread.currentThread().setName(name); + assertEquals(v, v2); + CommonThreadPool.shutdownAsyncPools(); } @Test public void testFromOtherThread() throws InterruptedException, ExecutionException { + CommonThreadPool.shutdownAsyncPools(); ExecutorService x = CommonThreadPool.get(5); Future a = x.submit(() -> CommonThreadPool.get(5)); ExecutorService y = a.get(); - assertNotEquals(x, y); + CommonThreadPool.shutdownAsyncPools(); } @Test public void testFromOtherThreadInfrastructureParallelism() throws InterruptedException, ExecutionException { + CommonThreadPool.shutdownAsyncPools(); final int k = InfrastructureAnalyzer.getLocalParallelism(); ExecutorService x = CommonThreadPool.get(k); Future a = x.submit(() -> CommonThreadPool.get(k)); ExecutorService y = a.get(); assertEquals(x, y); + CommonThreadPool.shutdownAsyncPools(); + } + + @Test + public void dynamic() throws InterruptedException, ExecutionException { + CommonThreadPool.shutdownAsyncPools(); + final int k = InfrastructureAnalyzer.getLocalParallelism(); + ExecutorService x = CommonThreadPool.getDynamicPool(); + Future a = x.submit(() -> CommonThreadPool.get(k)); + ExecutorService y = a.get(); + assertNotEquals(x, y); + CommonThreadPool.shutdownAsyncPools(); + } + + @Test + public void dynamicSame() throws InterruptedException, ExecutionException { + CommonThreadPool.shutdownAsyncPools(); + ExecutorService x = CommonThreadPool.getDynamicPool(); + ExecutorService y = CommonThreadPool.getDynamicPool(); + assertEquals(x, y); + CommonThreadPool.shutdownAsyncPools(); + } + + @Test + public void isSharedTPThreads() throws InterruptedException, ExecutionException { + CommonThreadPool.shutdownAsyncPools(); + for(int i = 0; i < 10; i++) + assertTrue(CommonThreadPool.isSharedTPThreads(i)); + + CommonThreadPool.shutdownAsyncPools(); + } + + @Test + public void isSharedTPThreadsCommonSize() throws InterruptedException, ExecutionException { + CommonThreadPool.shutdownAsyncPools(); + assertTrue(CommonThreadPool.isSharedTPThreads(InfrastructureAnalyzer.getLocalParallelism())); + CommonThreadPool.shutdownAsyncPools(); + } + + @Test + public void isSharedTPThreadsFalse() throws InterruptedException, ExecutionException { + CommonThreadPool.shutdownAsyncPools(); + String name = Thread.currentThread().getName(); + Thread.currentThread().setName("main"); + CommonThreadPool.get(18); + for(int i = 1; i < 10; i++) + if(i != InfrastructureAnalyzer.getLocalParallelism()) + assertFalse("" + i, CommonThreadPool.isSharedTPThreads(i)); + assertTrue(CommonThreadPool.isSharedTPThreads(18)); + assertFalse(CommonThreadPool.isSharedTPThreads(19)); + + Thread.currentThread().setName(name); + CommonThreadPool.shutdownAsyncPools(); + } + + @Test + public void justWorks() throws InterruptedException, ExecutionException { + + String name = Thread.currentThread().getName(); + Thread.currentThread().setName("main"); + for(int j = 0; j < 2; j++) { + for(int i = 4; i < 17; i++) { + ExecutorService p = CommonThreadPool.get(i); + final Integer l = i; + assertEquals(l, p.submit(() -> l).get()); + p.shutdown(); + } + } + Thread.currentThread().setName(name); + } + + @Test + public void justWorksNotMain() throws InterruptedException, ExecutionException { + + for(int j = 0; j < 2; j++) { + + for(int i = 4; i < 10; i++) { + ExecutorService p = CommonThreadPool.get(i); + final Integer l = i; + assertEquals(l, p.submit(() -> l).get()); + p.shutdown(); + + } + } + } + + @Test + public void justWorksShutdownNow() throws InterruptedException, ExecutionException { + + String name = Thread.currentThread().getName(); + Thread.currentThread().setName("main"); + for(int j = 0; j < 2; j++) { + + for(int i = 4; i < 16; i++) { + ExecutorService p = CommonThreadPool.get(i); + final Integer l = i; + assertEquals(l, p.submit(() -> l).get()); + p.shutdownNow(); + + } + } + Thread.currentThread().setName(name); + } + + @Test + public void justWorksShutdownNowNotMain() throws InterruptedException, ExecutionException { + + for(int j = 0; j < 2; j++) { + + for(int i = 4; i < 16; i++) { + ExecutorService p = CommonThreadPool.get(i); + final Integer l = i; + assertEquals(l, p.submit(() -> l).get()); + p.shutdownNow(); + + } + } + } + + @Test + public void mock1() throws NoSuchFieldException, SecurityException, IllegalArgumentException, IllegalAccessException, + InterruptedException, ExecutionException, TimeoutException { + + ExecutorService p = mock(ExecutorService.class); + ExecutorService c = new CommonThreadPool(p); + + when(p.shutdownNow()).thenReturn(null); + assertNull(c.shutdownNow()); + + Collection> cc = (Collection>) null; + when(p.invokeAll(cc)).thenReturn(null); + assertNull(c.invokeAll(cc)); + when(p.invokeAll(cc, 1L, TimeUnit.DAYS)).thenReturn(null); + assertNull(c.invokeAll(cc, 1, TimeUnit.DAYS)); + doNothing().when(p).execute((Runnable) null); + c.execute((Runnable) null); + + when(p.submit((Callable) null)).thenReturn(null); + assertNull(c.submit((Callable) null)); + + when(p.submit((Runnable) null, null)).thenReturn(null); + assertNull(c.submit((Runnable) null, null)); + // when(tp.pool()).thenReturn(p); + + when(p.submit((Runnable) null)).thenReturn(null); + assertNull(c.submit((Runnable) null)); + + when(p.isShutdown()).thenReturn(false); + assertFalse(c.isShutdown()); + when(p.isShutdown()).thenReturn(true); + assertTrue(c.isShutdown()); + + when(p.isTerminated()).thenReturn(false); + assertFalse(c.isTerminated()); + when(p.isTerminated()).thenReturn(true); + assertTrue(c.isTerminated()); + + when(p.awaitTermination(10, TimeUnit.DAYS)).thenReturn(false); + assertFalse(c.awaitTermination(10, TimeUnit.DAYS)); + when(p.awaitTermination(10, TimeUnit.DAYS)).thenReturn(true); + assertTrue(c.awaitTermination(10, TimeUnit.DAYS)); + + when(p.invokeAny(cc)).thenReturn(null); + assertNull(c.invokeAny(cc)); + when(p.invokeAny(cc, 1L, TimeUnit.DAYS)).thenReturn(null); + assertNull(c.invokeAny(cc, 1, TimeUnit.DAYS)); + doNothing().when(p).execute((Runnable) null); + c.execute((Runnable) null); + } + @Test + public void mock2() throws NoSuchFieldException, SecurityException, IllegalArgumentException, IllegalAccessException, + InterruptedException, ExecutionException, TimeoutException { + + CommonThreadPool p = mock(CommonThreadPool.class); + when(p.isShutdown()).thenCallRealMethod(); + when(p.isTerminated()).thenCallRealMethod(); + when(p.awaitTermination(10, TimeUnit.DAYS)).thenCallRealMethod(); + when(p.isCached()).thenReturn(true); + assertTrue(p.isShutdown()); + assertTrue(p.isTerminated()); + assertTrue(p.awaitTermination(10, TimeUnit.DAYS)); + } + + @Test + public void coverEdge() { + ExecutorService a = CommonThreadPool.get(InfrastructureAnalyzer.getLocalParallelism()); + assertTrue(new CommonThreadPool(a).isCached()); + } + + @Test(expected = DMLRuntimeException.class) + public void invokeAndShutdownException() throws InterruptedException { + ExecutorService p = mock(ExecutorService.class); + ExecutorService c = new CommonThreadPool(p); + + when(p.invokeAll(null)).thenThrow(new RuntimeException("Test")); + + CommonThreadPool.invokeAndShutdown(p, null); + + } + + @Test + public void invokeAndShutdown() throws InterruptedException { + + ExecutorService p = mock(ExecutorService.class); + ExecutorService c = new CommonThreadPool(p); + + Collection> cc = (Collection>) null; + when(p.invokeAll(cc)).thenReturn(new ArrayList>()); + + CommonThreadPool.invokeAndShutdown(c, null); + + } + + @Test + @SuppressWarnings("all") + public void invokeAndShutdownV2() throws InterruptedException{ + + ExecutorService p = mock(ExecutorService.class); + ExecutorService c = new CommonThreadPool(p); + + Collection> cc = (Collection>) null; + List> f = new ArrayList>(); + f.add(mock(Future.class)); + when(p.invokeAll(cc)).thenReturn(f ); + + CommonThreadPool.invokeAndShutdown(c, null); + + } }