Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closing channels #9

Merged
merged 5 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
408 changes: 359 additions & 49 deletions core/src/main/java/jox/Channel.java

Large diffs are not rendered by default.

19 changes: 19 additions & 0 deletions core/src/main/java/jox/ChannelClosed.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package jox;

public sealed interface ChannelClosed permits ChannelClosed.ChannelDone, ChannelClosed.ChannelError {
ChannelClosedException toException();

record ChannelDone() implements ChannelClosed {
@Override
public ChannelClosedException toException() {
return new ChannelClosedException.ChannelDoneException();
}
}

record ChannelError(Throwable cause) implements ChannelClosed {
@Override
public ChannelClosedException toException() {
return new ChannelClosedException.ChannelErrorException(cause);
}
}
}
24 changes: 24 additions & 0 deletions core/src/main/java/jox/ChannelClosedException.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package jox;

public sealed class ChannelClosedException extends RuntimeException permits ChannelClosedException.ChannelDoneException, ChannelClosedException.ChannelErrorException {
public ChannelClosedException() {
}

public ChannelClosedException(Throwable cause) {
super(cause);
}

public static final class ChannelDoneException extends ChannelClosedException {
public ChannelDoneException() {
}
}

public static final class ChannelErrorException extends ChannelClosedException {
public ChannelErrorException() {
}

public ChannelErrorException(Throwable cause) {
super(cause);
}
}
}
108 changes: 78 additions & 30 deletions core/src/main/java/jox/Segment.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,26 @@ final class Segment {
private static final int POINTERS_SHIFT = 12;
static final Segment NULL_SEGMENT = new Segment(-1, null, 0, false);

/**
* Used in {@code next} to indicate that the segment is closed.
*/
private enum State {
CLOSED
}

private final long id;
private final AtomicReferenceArray<Object> data = new AtomicReferenceArray<>(SEGMENT_SIZE);
private final AtomicReference<Segment> next = new AtomicReference<>(null);
/**
* Possible values: {@code Segment} or {@code State.CLOSED} (union type).
*/
private final AtomicReference<Object> next = new AtomicReference<>(null);
private final AtomicReference<Segment> prev;
/**
* A single counter that can be inspected & modified atomically, which includes:
* - the number of incoming pointers (shifted by {@link Segment#POINTERS_SHIFT} to the left)
* - the number of cells, which haven't been processed by {@code Channel.expandBuffer} yet (shifted by {@link Segment#PROCESSED_SHIFT} to the left)
* - the number of cells, which haven't been interrupted yet (in the first 6 bits)
* - the number of cells, which haven't been processed by {@code Channel.expandBuffer} or closed yet
* (shifted by {@link Segment#PROCESSED_SHIFT} to the left)
* - the number of cells, which haven't been interrupted or closed yet (in the first 6 bits)
* When this reaches 0, the segment is logically removed.
*/
private final AtomicInteger pointers_notProcessed_notInterrupted;
Expand All @@ -42,28 +52,21 @@ long getId() {
return id;
}

Segment getPrev() {
return prev.get();
}

void setPrev(Segment newPrev) {
prev.set(newPrev);
}

void cleanPrev() {
prev.set(null);
}

Segment getNext() {
return next.get();
var s = next.get();
return s == State.CLOSED ? null : (Segment) s;
}

boolean casNext(Segment expected, Segment setTo) {
return next.compareAndSet(expected, setTo);
Segment getPrev() {
return prev.get();
}

void setNext(Segment newNext) {
next.set(newNext);
private boolean setNextIfNull(Segment setTo) {
return next.compareAndSet(null, setTo);
}

Object getCell(int index) {
Expand All @@ -79,7 +82,7 @@ boolean casCell(int index, Object expected, Object newValue) {
}

private boolean isTail() {
return next.get() == null;
return getNext() == null;
}

/**
Expand Down Expand Up @@ -124,10 +127,12 @@ void cellInterruptedReceiver() {
}

/**
* Notify the segment that a `send` has been interrupted in the cell. Also marks the cell as processed.
* Notify the segment that a `send` has been interrupted in the cell, or that the cell has been closed. Also marks
* the cell as processed.
* <p>
* Should be called at most once for each cell. Removes the segment, if it becomes logically removed.
*/
void cellInterruptedSender() {
void cellInterruptedSender_orClosed() {
if (countProcessed) {
// decrementing both counters in a single operation
if (pointers_notProcessed_notInterrupted.addAndGet(-(1 << PROCESSED_SHIFT) - 1) == 0) remove();
Expand All @@ -138,7 +143,8 @@ void cellInterruptedSender() {

/**
* Notify the segment that a cell has been processed by {@code Channel.expandBuffer}. Should not be called
* in the cell has an interrupted sender.
* if the cell has an interrupted sender.
* <p>
* Should be called at most once for each cell. Removes the segment, if it becomes logically removed.
*/
void cellProcessed_notInterruptedSender() {
Expand All @@ -160,7 +166,7 @@ void remove() {

// link next and prev
_next.prev.updateAndGet(p -> p == null ? null : _prev);
if (_prev != null) _prev.setNext(_next);
if (_prev != null) _prev.next.set(_next);

// double-checking if _prev & _next are still not removed
if (_next.isRemoved() && !_next.isTail()) continue;
Expand All @@ -171,6 +177,26 @@ void remove() {
}
}

/**
* Closes the segment chain - sets the {@code next} pointer of the last segment to {@code State.CLOSED}, and returns the last segment.
*/
Segment close() {
var s = this;
while (true) {
var n = s.next.get();
if (n == null) { // this is the tail segment
if (s.next.compareAndSet(null, State.CLOSED)) {
return s;
}
// else: try again
} else if (n == State.CLOSED) {
return s;
} else {
s = (Segment) n;
}
}
}

private Segment aliveSegmentLeft() {
var s = prev.get();
while (s != null && s.isRemoved()) {
Expand All @@ -179,23 +205,31 @@ private Segment aliveSegmentLeft() {
return s;
}

/**
* Should only be called, if this is not the tail segment.
*/
private Segment aliveSegmentRight() {
var s = next.get();
while (s.isRemoved() && !s.isTail()) {
s = s.next.get();
var n = (Segment) next.get(); // this is not the tail, so there's a next segment
while (n.isRemoved() && !n.isTail()) {
n = (Segment) n.next.get(); // again, not tail
}
return s;
return n;
}

//

/**
* Finds or creates a non-removed segment with an id at least {@code id}, starting from {@code start}, and updates
* the {@code ref} reference to it.
*
* @return The found segment, or {@code null} if the segment chain is closed.
*/
static Segment findAndMoveForward(AtomicReference<Segment> ref, Segment start, long id, boolean countProcessed) {
while (true) {
var segment = findSegment(start, id, countProcessed);
if (segment == null) {
return null;
}
if (moveForward(ref, segment)) {
return segment;
}
Expand All @@ -205,21 +239,29 @@ static Segment findAndMoveForward(AtomicReference<Segment> ref, Segment start, l
/**
* Finds a non-removed segment with an id at least {@code id}, starting from {@code start}. New segments are created
* if needed; this might prompt physical removal of the previously-tail segment.
*
* @return The found segment, or {@code null} if the segment chain is closed.
*/
private static Segment findSegment(Segment start, long id, boolean countProcessed) {
var current = start;
while (current.getId() < id || current.isRemoved()) {
// create a new segment if needed
if (current.getNext() == null) {
var n = current.next.get();
if (n == State.CLOSED) {
// segment chain is closed, so we can't create a new segment
return null;
} else if (n == null) {
// create a new segment if needed
var newSegment = new Segment(current.getId() + 1, current, 0, countProcessed);
if (current.casNext(null, newSegment)) {
if (current.setNextIfNull(newSegment)) {
if (current.isRemoved()) {
// the current segment was a tail segment, so if it was logically removed, we need to remove it physically
current.remove();
}
}
// else: try again with current
} else {
current = (Segment) n;
}
current = current.getNext();
}
return current;
}
Expand Down Expand Up @@ -272,11 +314,17 @@ public String toString() {

return "Segment{" +
"id=" + id +
", next=" + (n == null ? "null" : n.id) +
", next=" + (n == null ? "null" : (n == State.CLOSED ? "closed" : ((Segment) n).id)) +
", prev=" + (p == null ? "null" : p.id) +
", pointers=" + pointers +
", notProcessed=" + notProcessed +
", notInterrupted=" + notInterrupted +
'}';
}

// for tests

void setNext(Segment newNext) {
next.set(newNext);
}
}
43 changes: 39 additions & 4 deletions core/src/test/java/jox/ChannelBufferedTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;

import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutionException;

import static jox.TestUtil.forkVoid;
import static jox.TestUtil.scoped;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;

/**
* Tests which always use buffered channels.
Expand Down Expand Up @@ -52,7 +50,6 @@ void testSimpleSendReceiveBuffer2() throws InterruptedException {
void testBufferCapacityStaysTheSameAfterSendsReceives() throws ExecutionException, InterruptedException {
// given
Channel<Integer> channel = new Channel<>(2);
var trail = new ConcurrentLinkedQueue<String>();

// when
scoped(scope -> {
Expand All @@ -77,4 +74,42 @@ void testBufferCapacityStaysTheSameAfterSendsReceives() throws ExecutionExceptio
channel.send(6); // should not block
});
}

@Test
@Timeout(1)
void shouldReceiveFromAChannelUntilDone() throws InterruptedException {
// given
Channel<Integer> c = new Channel<>(3);
c.send(1);
c.send(2);
c.done();

// when
var r1 = c.receiveSafe();
var r2 = c.receiveSafe();
var r3 = c.receiveSafe();

// then
assertEquals(1, r1);
assertEquals(2, r2);
assertInstanceOf(ChannelClosed.class, r3);
}

@Test
@Timeout(1)
void shouldNotReceiveFromAChannelInCaseOfAnError() throws InterruptedException {
// given
Channel<Integer> c = new Channel<>(3);
c.send(1);
c.send(2);
c.error(new RuntimeException());

// when
var r1 = c.receiveSafe();
var r2 = c.receiveSafe();

// then
assertInstanceOf(ChannelClosed.ChannelError.class, r1);
assertInstanceOf(ChannelClosed.ChannelError.class, r2);
}
}
Loading