Skip to content

Commit

Permalink
Add flattenPar, fix type issues, fix from method working only once
Browse files Browse the repository at this point in the history
  • Loading branch information
emil-bar committed Jan 10, 2025
1 parent 3678a8a commit a2fcac9
Show file tree
Hide file tree
Showing 5 changed files with 236 additions and 18 deletions.
93 changes: 82 additions & 11 deletions flows/src/main/java/com/softwaremill/jox/flows/Flow.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,13 @@
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.time.Duration;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
Expand All @@ -55,6 +53,7 @@
import com.softwaremill.jox.structured.Fork;
import com.softwaremill.jox.structured.Scope;
import com.softwaremill.jox.structured.Scopes;
import com.softwaremill.jox.structured.ThrowingRunnable;
import com.softwaremill.jox.structured.UnsupervisedScope;

/**
Expand Down Expand Up @@ -551,7 +550,7 @@ public <U> Flow<Map.Entry<T, U>> zip(Flow<U> other) {
} else {
// noinspection unchecked
U u = (U) result2;
emit.apply(new AbstractMap.SimpleEntry<>(t, u));
emit.apply(Map.entry(t, u));
}
}
}
Expand Down Expand Up @@ -625,14 +624,80 @@ public Flow<Map.Entry<T, Long>> zipWithIndex() {
}

@SafeVarargs
public final <U> Flow<U> flatten(T... args) {
if (!Flow.class.equals(args.getClass().getComponentType())) {
public final T flatten(T... args) {
if (!Flow.class.equals(getTClass(args))) {
throw new IllegalArgumentException("requirement failed: flatten can be called on Flow containing Flows");
}
//noinspection unchecked
return this.flatMap(t -> ((Flow<U>) t));
//noinspection unchecked,rawtypes
return (T) this.flatMap(t -> (Flow) t);
}

@SuppressWarnings("unchecked")
@SafeVarargs
public final <U> T flattenPar(int parallelism, T... args) {
if (!Flow.class.equals(getTClass(args))) {
throw new IllegalArgumentException("requirement failed: flattenPar can be called on Flow containing Flows");
}
return (T) Flows.usingEmit(emit -> {
class Nested {
final Flow<U> child;
Nested(Flow<U> child) {
this.child = child;
}
}
final class ChildDone {}

unsupervised(scope -> {
Channel<U> childOutputChannel = Channel.withScopedBufferSize();
Channel<ChildDone> childDoneChannel = Channel.withScopedBufferSize();

// When an error occurs in the parent, propagating it also to `childOutputChannel`, from which we always
// `select` in the main loop. That way, even if max parallelism is reached, errors in the parent will
// be discovered without delay.
//noinspection unchecked
Source<Nested> parentChannel = map(t -> new Nested((Flow<U>) t))
.onError(childOutputChannel::error)
.runToChannel(scope);

int runningChannelCount = 1; // parent is running
boolean parentDone = false;

while (runningChannelCount > 0) {
assert runningChannelCount <= parallelism + 1;

Object result;
if (runningChannelCount == parallelism + 1 || parentDone) {
result = selectOrClosed(childOutputChannel.receiveClause(), childDoneChannel.receiveClause());
} else {
result = selectOrClosed(childOutputChannel.receiveClause(), childDoneChannel.receiveClause(), parentChannel.receiveClause());
}

// Only `parentChannel` might be done, child completion is signalled via `childDoneChannel`.
if (result instanceof ChannelDone) {
parentDone = parentChannel.isClosedForReceive();
assert parentDone;
runningChannelCount--;
} else if (result instanceof ChannelError e) {
throw e.toException();
} else if (ChildDone.class.isInstance(result)) {
runningChannelCount--;
} else if (Nested.class.isInstance(result)) {
//noinspection unchecked
Nested t = (Nested) result;
scope.forkUnsupervised(() -> {
t.child.onDone(() -> childDoneChannel.send(new ChildDone()))
.runPipeToSink(childOutputChannel, false);
return null;
});
runningChannelCount++;
} else if (result != null) {
emit.apply(result);
}
}
return null;
});
});
}

/**
* Applies the given `mappingFunction` to each element emitted by this flow, in sequence.
Expand Down Expand Up @@ -910,7 +975,7 @@ public Flow<T> onComplete(Runnable f) {
/**
* Runs `f` after the flow completes successfully, that is when all elements are emitted.
*/
public Flow<T> onDone(Runnable f) {
public Flow<T> onDone(ThrowingRunnable f) {
return usingEmit(emit -> {
last.run(emit);
f.run();
Expand Down Expand Up @@ -1416,7 +1481,6 @@ public Publisher<T> toPublisher(Scope scope) {
// endregion

// region ByteFlow

public interface ByteChunkMapper<T> extends Function<T, ByteChunk> {}
public interface ByteArrayMapper<T> extends Function<T, byte[]> {}

Expand All @@ -1427,8 +1491,7 @@ public interface ByteArrayMapper<T> extends Function<T, byte[]> {}
*/
@SafeVarargs
public final ByteFlow toByteFlow(T... args) {
//noinspection unchecked
return new ByteFlow(last, (Class<T>) args.getClass().getComponentType());
return new ByteFlow(last, getTClass(args));
}

/**
Expand Down Expand Up @@ -1651,6 +1714,14 @@ private void runLastToChannelAsync(UnsupervisedScope scope, Channel<T> channel)
});
}

@SuppressWarnings("unchecked")
private static <T> Class<T> getTClass(T[] args) {
if (args.length > 0) {
throw new IllegalArgumentException("Please do not pass any arguments for this method. Java will detect the type automatically.");
}
return (Class<T>) args.getClass().getComponentType();
}

private static class BreakException extends RuntimeException {
}
}
39 changes: 35 additions & 4 deletions flows/src/main/java/com/softwaremill/jox/flows/Flows.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import java.nio.file.StandardOpenOption;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -60,29 +59,48 @@ public static <T> Flow<T> fromSource(Source<T> source) {

/**
* Creates a flow from the given `iterable`. Each element of the iterable is emitted in order.
* The flow can be run multiple times.
*/
public static <T> Flow<T> fromIterable(Iterable<T> iterable) {
return fromIterator(iterable.iterator());
return usingEmit(emit -> {
for (T t : iterable) {
emit.apply(t);
}
});
}

/**
* Creates a flow from the given values. Each value is emitted in order.
* Flow can be run multiple times.
*/
@SafeVarargs
public static <T> Flow<T> fromValues(T... ts) {
return fromIterator(Arrays.asList(ts).iterator());
return usingEmit(emit -> {
for (T t : ts) {
emit.apply(t);
}
});
}

/**
* Creates a ByteFlow from given {@link ByteChunk}s. Each ByteChunk is emitted in order.
* Flow can be run multiple times.
*/
public static Flow.ByteFlow fromByteChunks(ByteChunk... chunks) {
return fromIterator(Arrays.asList(chunks).iterator()).toByteFlow();
return fromValues(chunks).toByteFlow();
}

/**
* Creates a ByteFlow from given byte[]. Each byte[] is emitted in order.
* Flow can be run multiple times.
*/
public static Flow.ByteFlow fromByteArrays(byte[]... byteArrays) {
return fromValues(byteArrays).toByteFlow();
}

/**
* Creates a flow from the given (lazily evaluated) `iterator`. Each element of the iterator is emitted in order.
* The flow can be run only once, as the iterator is consumed. If you need to run the flow multiple times, use {@link #fromIterator(Supplier)}.
*/
public static <T> Flow<T> fromIterator(Iterator<T> it) {
return usingEmit(emit -> {
Expand All @@ -92,6 +110,19 @@ public static <T> Flow<T> fromIterator(Iterator<T> it) {
});
}

/**
* Creates a flow from the given (lazily evaluated) `iterator`. Each element of the iterator is emitted in order.
* The flow can be run multiple times, as the `iteratorSupplier` is called each time the flow is run.
*/
public static <T> Flow<T> fromIterator(Supplier<Iterator<T>> iteratorSupplier) {
return usingEmit(emit -> {
var it = iteratorSupplier.get();
while (it.hasNext()) {
emit.apply(it.next());
}
});
}

/**
* Creates a flow from the given fork. The flow will emit up to one element, or complete by throwing an exception if the fork fails.
*/
Expand Down
103 changes: 101 additions & 2 deletions flows/src/test/java/com/softwaremill/jox/flows/FlowFlattenTest.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
package com.softwaremill.jox.flows;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import java.time.Duration;
import java.util.ArrayList;
import java.util.List;

import org.hamcrest.Matchers;
import org.junit.jupiter.api.Test;

public class FlowFlattenTest {
Expand All @@ -15,16 +21,109 @@ void flattenTest() throws Exception {
Flow<Flow<Integer>> flow = Flows.fromValues(Flows.fromValues(1, 2), Flows.fromValues(5, 9));

// when
List<Integer> integers = flow.<Integer>flatten().runToList();
List<Integer> integers = flow.flatten().runToList();

// then
assertEquals(List.of(1, 2, 5, 9), integers);
}

@Test
void shouldThrowWhenCalledOnFlowNotContainingFlows() {
void flatten_shouldThrowWhenCalledOnFlowNotContainingFlows() {
// when & then
IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> Flows.fromValues(3, 3).flatten());
assertEquals("requirement failed: flatten can be called on Flow containing Flows", exception.getMessage());
}

@Test
void flattenPar_shouldThrowWhenCalledOnFlowNotContainingFlows() {
// when & then
IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
() -> Flows.fromValues(3, 3).flattenPar(1));
assertEquals("requirement failed: flattenPar can be called on Flow containing Flows", exception.getMessage());
}

@Test
void shouldPipeAllElementsOfTheChildFlowsIntoTheOutputFlow() throws Exception {
// given
var flow = Flows.fromValues(
Flows.fromValues(10),
Flows.fromValues(20, 30),
Flows.fromValues(40, 50, 60)
);

// when & then
List<Integer> actual = flow.flattenPar(10).runToList();
assertThat(actual, containsInAnyOrder(10, 20, 30, 40, 50, 60));
}

@Test
void shouldHandleEmptyFlow() throws Exception {
// given
var flow = Flows.<Flow<?>>empty();

// when & then
assertThat(flow.flattenPar(10).runToList(), Matchers.empty());
}

@Test
void shouldHandleSingletonFlow() throws Exception {
// given
var flow = Flows.fromValues(Flows.fromValues(10));

// when & then
List<Integer> objects = flow.flattenPar(10).runToList();
assertThat(objects, contains(10));
}

@Test
void shouldNotFlattenNestedFlows() throws Exception {
// given
var flow = Flows.fromValues(Flows.fromValues(Flows.fromValues(10)));

// when
List<Flow<Integer>> flows = flow.flattenPar(10).runToList();

// then
List<Integer> result = new ArrayList<>();
for (Flow<Integer> f : flows) {
List<Integer> integers = f.runToList();
result.addAll(integers);
}
assertThat(result, contains(10));
}

@Test
void shouldHandleSubsequentFlattenCalls() throws Exception {
// given
var flow = Flows.fromValues(Flows.fromValues(Flows.fromValues(10), Flows.fromValues(20)));

// when & then
var result = flow.flattenPar(10)
.runToList().stream().flatMap(f -> {
try {
return f.runToList().stream();
} catch (Exception e) {
throw new RuntimeException(e);
}
})
.toList();

assertThat(result, containsInAnyOrder(10, 20));
}

@Test
void shouldRunAtMostParallelismChildFlows() throws Exception {
// given
var flow = Flows.fromValues(
Flows.timeout(Duration.ofMillis(200)).concat(Flows.fromValues(10)),
Flows.timeout(Duration.ofMillis(100)).concat(Flows.fromValues(20, 30)),
Flows.fromValues(40, 50, 60)
);

// when & then
// only one flow can run at a time
assertThat(flow.flattenPar(1).runToList(), contains(10, 20, 30, 40, 50, 60));
// when parallelism is increased, all flows are run concurrently
assertThat(flow.flattenPar(3).runToList(), contains(40, 50, 60, 20, 30, 10));
}
}
18 changes: 18 additions & 0 deletions flows/src/test/java/com/softwaremill/jox/flows/FlowsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,24 @@ void shouldUnfoldFunction() throws Exception {
assertEquals(List.of(0, 1, 2), c.runToList());
}

@Test
void shouldRunFromIteratorOnlyOnce() throws Exception {
Flow<Integer> flow = Flows.fromIterator(List.of(1, 2, 3).iterator());

assertEquals(List.of(1, 2, 3), flow.runToList()); // first run traverses iterator
assertEquals(Collections.emptyList(), flow.runToList()); // second run is empty, as iterator is exhausted
}

@Test
void shouldRunFromIteratorSupplierMultipleTimes() throws Exception {
List<Integer> source = List.of(1, 2, 3);
var flow = Flows.fromIterator(source::iterator);

for (int i = 0; i < 5; i++) {
assertEquals(List.of(1, 2, 3), flow.runToList()); // each run gets new iterator, and is able to traverse it
}
}

private List<String> toStrings(Flow<ByteChunk> source) throws Exception {
return source.runToList().stream()
.map(chunk -> chunk.convertToString(StandardCharsets.UTF_8))
Expand Down
Loading

0 comments on commit a2fcac9

Please sign in to comment.