Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing interruption behaviour #3183

Merged
merged 13 commits into from
May 12, 2023
6 changes: 5 additions & 1 deletion core/shared/src/main/scala/fs2/Stream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1477,7 +1477,11 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
}

def endSupply(result: Either[Throwable, Unit]): F2[Unit] =
buffer.update(_.copy(endOfSupply = Some(result))) *> supply.releaseN(Int.MaxValue)
buffer.update(_.copy(endOfSupply = Some(result))) *> supply.releaseN(
// enough supply for 2 iterations of the race loop in case of upstream
// interruption: so that downstream can terminate immediately
outputLong * 2
)

def endDemand(result: Either[Throwable, Unit]): F2[Unit] =
buffer.update(_.copy(endOfDemand = Some(result))) *> demand.releaseN(Int.MaxValue)
Expand Down
71 changes: 70 additions & 1 deletion core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ package fs2

import cats.effect.kernel.Deferred
import cats.effect.kernel.Ref
import cats.effect.std.{Semaphore, Queue}
import cats.effect.std.{Queue, Semaphore}
import cats.effect.testkit.TestControl
import cats.effect.{IO, SyncIO}
import cats.syntax.all._
Expand All @@ -34,6 +34,7 @@ import org.scalacheck.Prop.forAll

import scala.concurrent.duration._
import scala.concurrent.TimeoutException
import scala.util.control.NoStackTrace

class StreamCombinatorsSuite extends Fs2Suite {
override def munitIOTimeout = 1.minute
Expand Down Expand Up @@ -834,6 +835,74 @@ class StreamCombinatorsSuite extends Fs2Suite {
)
.assertEquals(0.millis)
}

test("upstream failures are propagated downstream") {
TestControl.executeEmbed {
case object SevenNotAllowed extends NoStackTrace

val source = Stream
.iterate(0)(_ + 1)
.covary[IO]
.evalTap(n => IO.raiseError(SevenNotAllowed).whenA(n == 7))

val downstream = source.groupWithin(100, 2.seconds).map(_.toList)

val expected = List(List(1, 2, 3, 4, 5, 6))

downstream.assertEmits(expected).intercept[SevenNotAllowed.type]
}
}

test(
"upstream interruption causes immediate downstream termination with all elements being emitted"
) {

val sourceTimeout = 5.5.seconds
val downstreamTimeout = sourceTimeout + 2.seconds

TestControl
.executeEmbed {
val source: Stream[IO, Int] =
Stream
.iterate(0)(_ + 1)
.covary[IO]
.meteredStartImmediately(1.second)
.interruptAfter(sourceTimeout)

// large chunkSize and timeout (no emissions expected in the window
// specified, unless source ends, due to interruption or
// natural termination (i.e runs out of elements)
val downstream: Stream[IO, Chunk[Int]] =
source.groupWithin(Int.MaxValue, 1.day)

downstream.compile.lastOrError
.timeout(downstreamTimeout)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this timeout necessary? Since its an executeEmbed test.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's because if the test fails we get a slightly better error message

java.util.concurrent.TimeoutException: 7500 milliseconds which can be easily associated to downstreamTimeout

otherwise we get this value on the diff which looks a bit random

_1 = 86405500000000 nanoseconds,

but I'm happy to remove it

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, that is a nicer error :) thanks!

.map(_.toList)
.timed
}
.assertEquals(
// downstream ended immediately (i.e timeLapsed = sourceTimeout)
// emitting whatever was accumulated at the time of interruption
(sourceTimeout, List(0, 1, 2, 3, 4, 5))
)
}

test("stress test: all elements are processed") {
val rangeLength = 10000

Stream
.eval(Ref.of[IO, Int](0))
.flatMap { counter =>
Stream
.range(0, rangeLength)
.covary[IO]
.groupWithin(4096, 100.micros)
.evalTap(ch => counter.update(_ + ch.size)) *> Stream.eval(counter.get)
}
.compile
.lastOrError
.assertEquals(rangeLength)
}
}

property("head")(forAll((s: Stream[Pure, Int]) => assertEquals(s.head.toList, s.toList.take(1))))
Expand Down