Skip to content

Commit

Permalink
Merge pull request #3183 from Angel-O/fixing-interruption-behaviour
Browse files Browse the repository at this point in the history
Fixing interruption behaviour
  • Loading branch information
armanbilge authored May 12, 2023
2 parents deb2e2f + 89889fa commit bfbb489
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 2 deletions.
6 changes: 5 additions & 1 deletion core/shared/src/main/scala/fs2/Stream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1477,7 +1477,11 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
}

def endSupply(result: Either[Throwable, Unit]): F2[Unit] =
buffer.update(_.copy(endOfSupply = Some(result))) *> supply.releaseN(Int.MaxValue)
buffer.update(_.copy(endOfSupply = Some(result))) *> supply.releaseN(
// enough supply for 2 iterations of the race loop in case of upstream
// interruption: so that downstream can terminate immediately
outputLong * 2
)

def endDemand(result: Either[Throwable, Unit]): F2[Unit] =
buffer.update(_.copy(endOfDemand = Some(result))) *> demand.releaseN(Int.MaxValue)
Expand Down
71 changes: 70 additions & 1 deletion core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ package fs2

import cats.effect.kernel.Deferred
import cats.effect.kernel.Ref
import cats.effect.std.{Semaphore, Queue}
import cats.effect.std.{Queue, Semaphore}
import cats.effect.testkit.TestControl
import cats.effect.{IO, SyncIO}
import cats.syntax.all._
Expand All @@ -34,6 +34,7 @@ import org.scalacheck.Prop.forAll

import scala.concurrent.duration._
import scala.concurrent.TimeoutException
import scala.util.control.NoStackTrace

class StreamCombinatorsSuite extends Fs2Suite {
override def munitIOTimeout = 1.minute
Expand Down Expand Up @@ -834,6 +835,74 @@ class StreamCombinatorsSuite extends Fs2Suite {
)
.assertEquals(0.millis)
}

test("upstream failures are propagated downstream") {
TestControl.executeEmbed {
case object SevenNotAllowed extends NoStackTrace

val source = Stream
.iterate(0)(_ + 1)
.covary[IO]
.evalTap(n => IO.raiseError(SevenNotAllowed).whenA(n == 7))

val downstream = source.groupWithin(100, 2.seconds).map(_.toList)

val expected = List(List(1, 2, 3, 4, 5, 6))

downstream.assertEmits(expected).intercept[SevenNotAllowed.type]
}
}

test(
"upstream interruption causes immediate downstream termination with all elements being emitted"
) {

val sourceTimeout = 5.5.seconds
val downstreamTimeout = sourceTimeout + 2.seconds

TestControl
.executeEmbed {
val source: Stream[IO, Int] =
Stream
.iterate(0)(_ + 1)
.covary[IO]
.meteredStartImmediately(1.second)
.interruptAfter(sourceTimeout)

// large chunkSize and timeout (no emissions expected in the window
// specified, unless source ends, due to interruption or
// natural termination (i.e runs out of elements)
val downstream: Stream[IO, Chunk[Int]] =
source.groupWithin(Int.MaxValue, 1.day)

downstream.compile.lastOrError
.timeout(downstreamTimeout)
.map(_.toList)
.timed
}
.assertEquals(
// downstream ended immediately (i.e timeLapsed = sourceTimeout)
// emitting whatever was accumulated at the time of interruption
(sourceTimeout, List(0, 1, 2, 3, 4, 5))
)
}

test("stress test: all elements are processed") {
val rangeLength = 10000

Stream
.eval(Ref.of[IO, Int](0))
.flatMap { counter =>
Stream
.range(0, rangeLength)
.covary[IO]
.groupWithin(4096, 100.micros)
.evalTap(ch => counter.update(_ + ch.size)) *> Stream.eval(counter.get)
}
.compile
.lastOrError
.assertEquals(rangeLength)
}
}

property("head")(forAll((s: Stream[Pure, Int]) => assertEquals(s.head.toList, s.toList.take(1))))
Expand Down

0 comments on commit bfbb489

Please sign in to comment.