diff --git a/core/src/main/java/jox/Channel.java b/core/src/main/java/jox/Channel.java index 2dc9229..8e3677f 100644 --- a/core/src/main/java/jox/Channel.java +++ b/core/src/main/java/jox/Channel.java @@ -6,8 +6,6 @@ import java.util.concurrent.atomic.AtomicReferenceArray; import java.util.concurrent.locks.LockSupport; -import static jox.CellState.*; - public class Channel { /* Inspired by the "Fast and Scalable Channels in Kotlin Coroutines" paper (https://arxiv.org/abs/2211.04986), and @@ -89,7 +87,7 @@ private boolean updateCellSend(long s, T value) throws InterruptedException { // storing the value to send as the continuation's payload, so that the receiver can use it var c = new Continuation(value); if (casState(s, null, c)) { - c.await(() -> setState(INTERRUPTED, s)); + c.await(() -> setState(CellState.INTERRUPTED, s)); return true; } // else: CAS unsuccessful, repeat @@ -104,14 +102,14 @@ private boolean updateCellSend(long s, T value) throws InterruptedException { case Continuation c -> { // a receiver is waiting -> trying to resume if (c.tryResume(value)) { - setState(DONE, s); + setState(CellState.DONE, s); return true; } else { // cell interrupted -> trying with a new one return false; } } - case INTERRUPTED, BROKEN -> { + case CellState.INTERRUPTED, CellState.BROKEN -> { // cell interrupted or poisoned -> trying with a new one return false; } @@ -162,11 +160,11 @@ private Object updateCellReceive(long r) throws InterruptedException { // not using any payload var c = new Continuation(null); if (casState(r, null, c)) { - return c.await(() -> setState(INTERRUPTED, r)); + return c.await(() -> setState(CellState.INTERRUPTED, r)); } // else: CAS unsuccessful, repeat } else { - if (casState(r, null, BROKEN)) { + if (casState(r, null, CellState.BROKEN)) { return UpdateCellReceiveResult.RESTART; } // else: CAS unsuccessful, repeat @@ -175,7 +173,7 @@ private Object updateCellReceive(long r) throws InterruptedException { case Continuation c -> { // a sender is waiting -> trying to resume if (c.tryResume(0)) { - setState(DONE, r); + setState(CellState.DONE, r); return c.getPayload(); } else { // cell interrupted -> trying with a new one @@ -186,7 +184,7 @@ private Object updateCellReceive(long r) throws InterruptedException { // an elimination has happened -> finish return b.value(); } - case INTERRUPTED -> { + case CellState.INTERRUPTED -> { // cell interrupted -> trying with a new one return UpdateCellReceiveResult.RESTART; } @@ -194,100 +192,111 @@ private Object updateCellReceive(long r) throws InterruptedException { } } } -} -// possible return values of updateCellReceive: one of the enum constants below, or the received value + // possible return values of updateCellReceive: one of the enum constants below, or the received value -enum UpdateCellReceiveResult { - RESTART -} + private enum UpdateCellReceiveResult { + RESTART + } -// possible states of a cell: one of the enum constants below, Buffered, or Continuation + // possible states of a cell: one of the enum constants below, Buffered, or Continuation -enum CellState { - DONE, - INTERRUPTED, - BROKEN; -} + private enum CellState { + DONE, + INTERRUPTED, + BROKEN; + } -// a java record called Buffered with a single value field; the type should be T -record Buffered(Object value) {} + // a java record called Buffered with a single value field; the type should be T + private record Buffered(Object value) {} -final class Continuation { - /** - * The number of busy-looping iterations before yielding, during {@link Continuation#await(Runnable)}. {@code 0}, if there's a single CPU. - */ - private static final int SPINS = Runtime.getRuntime().availableProcessors() == 1 ? 0 : 10000; + private static final class Continuation { + /** + * The number of busy-looping iterations before yielding, during {@link Continuation#await(Runnable)}. {@code 0}, if there's a single CPU. + */ + private static final int SPINS = Runtime.getRuntime().availableProcessors() == 1 ? 0 : 10000; - private final Thread creatingThread; - private volatile Object data; // set using DATA var handle + private final Thread creatingThread; + private volatile Object data; // set using DATA var handle - private final Object payload; + private final Object payload; - Continuation(Object payload) { - this.payload = payload; - this.creatingThread = Thread.currentThread(); - } + Continuation(Object payload) { + this.payload = payload; + this.creatingThread = Thread.currentThread(); + } - /** - * Resume the continuation with the given value. - * - * @param value Should not be {@code null}. - * @return {@code true} tf the continuation was resumed successfully. {@code false} if it was interrupted. - */ - boolean tryResume(Object value) { - var result = Continuation.DATA.compareAndSet(this, null, value); - LockSupport.unpark(creatingThread); - return result; - } + /** + * Resume the continuation with the given value. + * + * @param value Should not be {@code null}. + * @return {@code true} tf the continuation was resumed successfully. {@code false} if it was interrupted. + */ + boolean tryResume(Object value) { + var result = Continuation.DATA.compareAndSet(this, null, value); + LockSupport.unpark(creatingThread); + return result; + } - /** - * Await for the continuation to be resumed. - * - * @param onInterrupt - * @return The value with which the continuation was resumed. - */ - Object await(Runnable onInterrupt) throws InterruptedException { - var spinIterations = SPINS; - while (data == null) { - if (spinIterations > 0) { - Thread.onSpinWait(); - spinIterations -= 1; - } else { - LockSupport.park(); - - if (Thread.interrupted()) { -// Continuation.STATE.compareAndSet(this, 0, 2); // TODO if - var e = new InterruptedException(); - - try { - onInterrupt.run(); - } catch (Throwable ee) { - e.addSuppressed(ee); + /** + * Await for the continuation to be resumed. + * + * @param onInterrupt + * @return The value with which the continuation was resumed. + */ + Object await(Runnable onInterrupt) throws InterruptedException { + var spinIterations = SPINS; + while (data == null) { + if (spinIterations > 0) { + Thread.onSpinWait(); + spinIterations -= 1; + } else { + LockSupport.park(); + + if (Thread.interrupted()) { + // potential race with `tryResume` + if (Continuation.DATA.compareAndSet(this, null, ContinuationMarker.INTERRUPTED)) { + var e = new InterruptedException(); + + try { + onInterrupt.run(); + } catch (Throwable ee) { + e.addSuppressed(ee); + } + + throw e; + } else { + // another thread already set the data; setting the interrupt status (so that the next blocking + // operation throws), and continuing + Thread.currentThread().interrupt(); + } } - - throw e; } } - } - return data; - } + return data; + } - Object getPayload() { - return payload; - } + Object getPayload() { + return payload; + } - // + // - private static final VarHandle DATA; + private static final VarHandle DATA; - static { - var l = MethodHandles.lookup(); - try { - DATA = l.findVarHandle(Continuation.class, "data", Object.class); - } catch (ReflectiveOperationException e) { - throw new ExceptionInInitializerError(e); + static { + var l = MethodHandles.lookup(); + try { + DATA = l.findVarHandle(Continuation.class, "data", Object.class); + } catch (ReflectiveOperationException e) { + throw new ExceptionInInitializerError(e); + } } } + + // the marker value is used only to mark in the continuation's `data` that interruption won the race with `tryResume` + private enum ContinuationMarker { + INTERRUPTED + } } \ No newline at end of file diff --git a/core/src/test/java/jox/ChannelInterruptionTest.java b/core/src/test/java/jox/ChannelInterruptionTest.java new file mode 100644 index 0000000..0396cb3 --- /dev/null +++ b/core/src/test/java/jox/ChannelInterruptionTest.java @@ -0,0 +1,68 @@ +package jox; + +import org.junit.jupiter.api.Test; + +import static jox.TestUtil.*; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class ChannelInterruptionTest { + @Test + public void testSendReceiveAfterSendInterrupt() throws Exception { + // given + Channel channel = new Channel<>(); + + // when + scoped(scope -> { + var t1 = forkCancelable(scope, () -> channel.send("x")); + t1.cancel(); + + forkVoid(scope, () -> channel.send("y")); + var t3 = fork(scope, channel::receive); + + // then + assertEquals("y", t3.get()); + }); + } + + @Test + public void testSendReceiveAfterReceiveInterrupt() throws Exception { + // given + Channel channel = new Channel<>(); + + // when + scoped(scope -> { + var t1 = forkCancelable(scope, channel::receive); + t1.cancel(); + + forkVoid(scope, () -> channel.send("x")); + var t3 = fork(scope, channel::receive); + + // then + assertEquals("x", t3.get()); + }); + } + + @Test + public void testRaceInterruptAndSend() throws Exception { + // when + scoped(scope -> { + for (int i = 0; i < 100; i++) { + // given + Channel channel = new Channel<>(); + + var t1 = forkCancelable(scope, () -> channel.send("x")); + var t2 = fork(scope, channel::receive); + t1.cancel(); + + if (t1.cancel() instanceof InterruptedException) { + // the `receive` from t2 has not happened yet + forkVoid(scope, () -> channel.send("y")); + assertEquals("y", t2.get()); + } else { + // the `receive` from t2 has already happened + assertEquals("x", t2.get()); + } + } + }); + } +} diff --git a/core/src/test/java/jox/ChannelTest.java b/core/src/test/java/jox/ChannelTest.java index ae97e12..ae47313 100644 --- a/core/src/test/java/jox/ChannelTest.java +++ b/core/src/test/java/jox/ChannelTest.java @@ -5,6 +5,7 @@ import java.util.HashSet; import java.util.concurrent.*; +import static jox.TestUtil.*; import static org.junit.jupiter.api.Assertions.assertEquals; public class ChannelTest { @@ -97,54 +98,4 @@ public void performanceTest() throws Exception { }); } } - - // - - private void scoped(ConsumerWithException> f) throws InterruptedException, ExecutionException { - try (var scope = new StructuredTaskScope.ShutdownOnFailure()) { - // making sure everything runs in a VT - scope.fork(() -> { - f.accept(scope); - return null; - }); - scope.join().throwIfFailed(); - } - } - - private Future fork(StructuredTaskScope scope, Callable c) { - var f = new CompletableFuture(); - scope.fork(() -> { - try { - f.complete(c.call()); - } catch (Exception ex) { - f.completeExceptionally(ex); - } - return null; - }); - return f; - } - - private Future forkVoid(StructuredTaskScope scope, RunnableWithException r) { - return fork(scope, () -> { - r.run(); - return null; - }); - } - - @FunctionalInterface - private interface ConsumerWithException { - void accept(T o) throws Exception; - } - - @FunctionalInterface - private interface RunnableWithException { - void run() throws Exception; - } - - private void timed(String label, RunnableWithException block) throws Exception { - var start = System.nanoTime(); - block.run(); - var end = System.nanoTime(); - System.out.println(label + " took: " + (end - start) / 1_000_000 + " ms"); - } } \ No newline at end of file diff --git a/core/src/test/java/jox/TestUtil.java b/core/src/test/java/jox/TestUtil.java new file mode 100644 index 0000000..47a3a4c --- /dev/null +++ b/core/src/test/java/jox/TestUtil.java @@ -0,0 +1,112 @@ +package jox; + +import java.util.concurrent.*; + +public class TestUtil { + public static void scoped(ConsumerWithException> f) throws InterruptedException, ExecutionException { + try (var scope = new StructuredTaskScope.ShutdownOnFailure()) { + // making sure everything runs in a VT + scope.fork(() -> { + f.accept(scope); + return null; + }); + scope.join().throwIfFailed(); + } + } + + public static Future fork(StructuredTaskScope scope, Callable c) { + var f = new CompletableFuture(); + scope.fork(() -> { + try { + f.complete(c.call()); + } catch (Exception ex) { + f.completeExceptionally(ex); + } + return null; + }); + return f; + } + + public static Fork forkCancelable(StructuredTaskScope scope, RunnableWithException c) { + return forkCancelable(scope, () -> { + c.run(); + return null; + }); + } + + public static Fork forkCancelable(StructuredTaskScope scope, Callable c) { + var f = new CompletableFuture(); + var t = Thread.ofVirtual().start(() -> { + try { + f.complete(c.call()); + } catch (Exception ex) { + f.completeExceptionally(ex); + } + }); + // supervisor + scope.fork(() -> { + try { + f.get(); + } catch (InterruptedException e) { + t.interrupt(); + } catch (ExecutionException e) { + if (!(e.getCause() instanceof InterruptedException)) { + throw e; + } // else ignore, already interrupted + } finally { + t.join(); + } + return null; + }); + return new Fork<>() { + @Override + public T get() throws ExecutionException, InterruptedException { + return f.get(); + } + + @Override + public Object cancel() throws InterruptedException, ExecutionException { + t.interrupt(); + t.join(); + if (f.isCompletedExceptionally()) { + return f.exceptionNow(); + } else { + return f.get(); + } + } + }; + } + + public static Future forkVoid(StructuredTaskScope scope, RunnableWithException r) { + return fork(scope, () -> { + r.run(); + return null; + }); + } + + @FunctionalInterface + public static interface ConsumerWithException { + void accept(T o) throws Exception; + } + + @FunctionalInterface + public static interface RunnableWithException { + void run() throws Exception; + } + + public static interface Fork { + T get() throws ExecutionException, InterruptedException; + + /** + * Either an exception, or T. Waits for the fork to complete. + */ + Object cancel() throws InterruptedException, ExecutionException; + } + + public static void timed(String label, RunnableWithException block) throws Exception { + var start = System.nanoTime(); + block.run(); + var end = System.nanoTime(); + System.out.println(label + " took: " + (end - start) / 1_000_000 + " ms"); + } +}