Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw committed Nov 20, 2023
1 parent fd54ab5 commit 3ca06ae
Show file tree
Hide file tree
Showing 4 changed files with 274 additions and 134 deletions.
177 changes: 93 additions & 84 deletions core/src/main/java/jox/Channel.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import java.util.concurrent.atomic.AtomicReferenceArray;
import java.util.concurrent.locks.LockSupport;

import static jox.CellState.*;

public class Channel<T> {
/*
Inspired by the "Fast and Scalable Channels in Kotlin Coroutines" paper (https://arxiv.org/abs/2211.04986), and
Expand Down Expand Up @@ -89,7 +87,7 @@ private boolean updateCellSend(long s, T value) throws InterruptedException {
// storing the value to send as the continuation's payload, so that the receiver can use it
var c = new Continuation(value);
if (casState(s, null, c)) {
c.await(() -> setState(INTERRUPTED, s));
c.await(() -> setState(CellState.INTERRUPTED, s));
return true;
}
// else: CAS unsuccessful, repeat
Expand All @@ -104,14 +102,14 @@ private boolean updateCellSend(long s, T value) throws InterruptedException {
case Continuation c -> {
// a receiver is waiting -> trying to resume
if (c.tryResume(value)) {
setState(DONE, s);
setState(CellState.DONE, s);
return true;
} else {
// cell interrupted -> trying with a new one
return false;
}
}
case INTERRUPTED, BROKEN -> {
case CellState.INTERRUPTED, CellState.BROKEN -> {
// cell interrupted or poisoned -> trying with a new one
return false;
}
Expand Down Expand Up @@ -162,11 +160,11 @@ private Object updateCellReceive(long r) throws InterruptedException {
// not using any payload
var c = new Continuation(null);
if (casState(r, null, c)) {
return c.await(() -> setState(INTERRUPTED, r));
return c.await(() -> setState(CellState.INTERRUPTED, r));
}
// else: CAS unsuccessful, repeat
} else {
if (casState(r, null, BROKEN)) {
if (casState(r, null, CellState.BROKEN)) {
return UpdateCellReceiveResult.RESTART;
}
// else: CAS unsuccessful, repeat
Expand All @@ -175,7 +173,7 @@ private Object updateCellReceive(long r) throws InterruptedException {
case Continuation c -> {
// a sender is waiting -> trying to resume
if (c.tryResume(0)) {
setState(DONE, r);
setState(CellState.DONE, r);
return c.getPayload();
} else {
// cell interrupted -> trying with a new one
Expand All @@ -186,108 +184,119 @@ private Object updateCellReceive(long r) throws InterruptedException {
// an elimination has happened -> finish
return b.value();
}
case INTERRUPTED -> {
case CellState.INTERRUPTED -> {
// cell interrupted -> trying with a new one
return UpdateCellReceiveResult.RESTART;
}
default -> throw new IllegalStateException("Unexpected state: " + state);
}
}
}
}

// possible return values of updateCellReceive: one of the enum constants below, or the received value
// possible return values of updateCellReceive: one of the enum constants below, or the received value

enum UpdateCellReceiveResult {
RESTART
}
private enum UpdateCellReceiveResult {
RESTART
}

// possible states of a cell: one of the enum constants below, Buffered, or Continuation
// possible states of a cell: one of the enum constants below, Buffered, or Continuation

enum CellState {
DONE,
INTERRUPTED,
BROKEN;
}
private enum CellState {
DONE,
INTERRUPTED,
BROKEN;
}

// a java record called Buffered with a single value field; the type should be T
record Buffered(Object value) {}
// a java record called Buffered with a single value field; the type should be T
private record Buffered(Object value) {}

final class Continuation {
/**
* The number of busy-looping iterations before yielding, during {@link Continuation#await(Runnable)}. {@code 0}, if there's a single CPU.
*/
private static final int SPINS = Runtime.getRuntime().availableProcessors() == 1 ? 0 : 10000;
private static final class Continuation {
/**
* The number of busy-looping iterations before yielding, during {@link Continuation#await(Runnable)}. {@code 0}, if there's a single CPU.
*/
private static final int SPINS = Runtime.getRuntime().availableProcessors() == 1 ? 0 : 10000;

private final Thread creatingThread;
private volatile Object data; // set using DATA var handle
private final Thread creatingThread;
private volatile Object data; // set using DATA var handle

private final Object payload;
private final Object payload;

Continuation(Object payload) {
this.payload = payload;
this.creatingThread = Thread.currentThread();
}
Continuation(Object payload) {
this.payload = payload;
this.creatingThread = Thread.currentThread();
}

/**
* Resume the continuation with the given value.
*
* @param value Should not be {@code null}.
* @return {@code true} tf the continuation was resumed successfully. {@code false} if it was interrupted.
*/
boolean tryResume(Object value) {
var result = Continuation.DATA.compareAndSet(this, null, value);
LockSupport.unpark(creatingThread);
return result;
}
/**
* Resume the continuation with the given value.
*
* @param value Should not be {@code null}.
* @return {@code true} tf the continuation was resumed successfully. {@code false} if it was interrupted.
*/
boolean tryResume(Object value) {
var result = Continuation.DATA.compareAndSet(this, null, value);
LockSupport.unpark(creatingThread);
return result;
}

/**
* Await for the continuation to be resumed.
*
* @param onInterrupt
* @return The value with which the continuation was resumed.
*/
Object await(Runnable onInterrupt) throws InterruptedException {
var spinIterations = SPINS;
while (data == null) {
if (spinIterations > 0) {
Thread.onSpinWait();
spinIterations -= 1;
} else {
LockSupport.park();

if (Thread.interrupted()) {
// Continuation.STATE.compareAndSet(this, 0, 2); // TODO if
var e = new InterruptedException();

try {
onInterrupt.run();
} catch (Throwable ee) {
e.addSuppressed(ee);
/**
* Await for the continuation to be resumed.
*
* @param onInterrupt
* @return The value with which the continuation was resumed.
*/
Object await(Runnable onInterrupt) throws InterruptedException {
var spinIterations = SPINS;
while (data == null) {
if (spinIterations > 0) {
Thread.onSpinWait();
spinIterations -= 1;
} else {
LockSupport.park();

if (Thread.interrupted()) {
// potential race with `tryResume`
if (Continuation.DATA.compareAndSet(this, null, ContinuationMarker.INTERRUPTED)) {
var e = new InterruptedException();

try {
onInterrupt.run();
} catch (Throwable ee) {
e.addSuppressed(ee);
}

throw e;
} else {
// another thread already set the data; setting the interrupt status (so that the next blocking
// operation throws), and continuing
Thread.currentThread().interrupt();
}
}

throw e;
}
}
}

return data;
}
return data;
}

Object getPayload() {
return payload;
}
Object getPayload() {
return payload;
}

//
//

private static final VarHandle DATA;
private static final VarHandle DATA;

static {
var l = MethodHandles.lookup();
try {
DATA = l.findVarHandle(Continuation.class, "data", Object.class);
} catch (ReflectiveOperationException e) {
throw new ExceptionInInitializerError(e);
static {
var l = MethodHandles.lookup();
try {
DATA = l.findVarHandle(Continuation.class, "data", Object.class);
} catch (ReflectiveOperationException e) {
throw new ExceptionInInitializerError(e);
}
}
}

// the marker value is used only to mark in the continuation's `data` that interruption won the race with `tryResume`
private enum ContinuationMarker {
INTERRUPTED
}
}
68 changes: 68 additions & 0 deletions core/src/test/java/jox/ChannelInterruptionTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package jox;

import org.junit.jupiter.api.Test;

import static jox.TestUtil.*;
import static org.junit.jupiter.api.Assertions.assertEquals;

public class ChannelInterruptionTest {
@Test
public void testSendReceiveAfterSendInterrupt() throws Exception {
// given
Channel<String> channel = new Channel<>();

// when
scoped(scope -> {
var t1 = forkCancelable(scope, () -> channel.send("x"));
t1.cancel();

forkVoid(scope, () -> channel.send("y"));
var t3 = fork(scope, channel::receive);

// then
assertEquals("y", t3.get());
});
}

@Test
public void testSendReceiveAfterReceiveInterrupt() throws Exception {
// given
Channel<String> channel = new Channel<>();

// when
scoped(scope -> {
var t1 = forkCancelable(scope, channel::receive);
t1.cancel();

forkVoid(scope, () -> channel.send("x"));
var t3 = fork(scope, channel::receive);

// then
assertEquals("x", t3.get());
});
}

@Test
public void testRaceInterruptAndSend() throws Exception {
// when
scoped(scope -> {
for (int i = 0; i < 100; i++) {
// given
Channel<String> channel = new Channel<>();

var t1 = forkCancelable(scope, () -> channel.send("x"));
var t2 = fork(scope, channel::receive);
t1.cancel();

if (t1.cancel() instanceof InterruptedException) {
// the `receive` from t2 has not happened yet
forkVoid(scope, () -> channel.send("y"));
assertEquals("y", t2.get());
} else {
// the `receive` from t2 has already happened
assertEquals("x", t2.get());
}
}
});
}
}
51 changes: 1 addition & 50 deletions core/src/test/java/jox/ChannelTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import java.util.HashSet;
import java.util.concurrent.*;

import static jox.TestUtil.*;
import static org.junit.jupiter.api.Assertions.assertEquals;

public class ChannelTest {
Expand Down Expand Up @@ -97,54 +98,4 @@ public void performanceTest() throws Exception {
});
}
}

//

private void scoped(ConsumerWithException<StructuredTaskScope<Object>> f) throws InterruptedException, ExecutionException {
try (var scope = new StructuredTaskScope.ShutdownOnFailure()) {
// making sure everything runs in a VT
scope.fork(() -> {
f.accept(scope);
return null;
});
scope.join().throwIfFailed();
}
}

private <T> Future<T> fork(StructuredTaskScope<Object> scope, Callable<T> c) {
var f = new CompletableFuture<T>();
scope.fork(() -> {
try {
f.complete(c.call());
} catch (Exception ex) {
f.completeExceptionally(ex);
}
return null;
});
return f;
}

private Future<Void> forkVoid(StructuredTaskScope<Object> scope, RunnableWithException r) {
return fork(scope, () -> {
r.run();
return null;
});
}

@FunctionalInterface
private interface ConsumerWithException<T> {
void accept(T o) throws Exception;
}

@FunctionalInterface
private interface RunnableWithException {
void run() throws Exception;
}

private void timed(String label, RunnableWithException block) throws Exception {
var start = System.nanoTime();
block.run();
var end = System.nanoTime();
System.out.println(label + " took: " + (end - start) / 1_000_000 + " ms");
}
}
Loading

0 comments on commit 3ca06ae

Please sign in to comment.