
Fixing interruption behaviour #3183

Merged: 13 commits, May 12, 2023
core/shared/src/main/scala/fs2/Stream.scala (3 additions, 1 deletion)
@@ -1471,7 +1471,9 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
}

   def endSupply(result: Either[Throwable, Unit]): F2[Unit] =
-    buffer.update(_.copy(endOfSupply = Some(result))) *> supply.releaseN(Int.MaxValue)
+    buffer.update(_.copy(endOfSupply = Some(result))) *> supply.releaseN(
+      Int.MaxValue + outputLong
+    )
Member

Sorry, dumb question: why is Int.MaxValue a "magic number" in this context? I would have thought it's effectively maxing out the semaphore, but if it needs + outputLong to work then I feel like it must have more significance?

Contributor Author (@Angel-O, Apr 2, 2023)

Legit question to be fair. Had to think about it again.

Interruption of the upstream fiber (i.e. Outcome.Cancelled) is handled downstream by doing nothing (permits are never released).
So by increasing the supply by Int.MaxValue we are just evening out the negative balance (Int.MaxValue accounts for the worst-case scenario: the chunkSize parameter will be at most Int.MaxValue).

    val waitSupply = supply.acquireN(outputLong).guaranteeCase {
      case Outcome.Succeeded(_) => supply.releaseN(outputLong)
      case _                    => F.unit
    }

Now, after getting past the "checkpoint" above, we acquire outputLong permits again:

    acq <- F.race(F.sleep(timeout), waitSupply).flatMap {
      case Left(_)  => onTimeout
      case Right(_) => supply.acquireN(outputLong).as(outputLong)
    }

So in order to get past this point we need to release outputLong additional permits, and that allows the stream to be unblocked.

EDIT

Interruption of the upstream fiber (i.e. Outcome.Cancelled)

Hmm, well, actually, I've just tested it: it is not handled with Outcome.Cancelled...
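The two-stage acquisition described above can be modeled with a plain semaphore. The following is a hypothetical script-style sketch (not the fs2 implementation) using java.util.concurrent.Semaphore in place of cats-effect's Semaphore, with an illustrative value for outputLong; it shows why releasing only outputLong permits is not enough once the upstream has been interrupted:

```scala
import java.util.concurrent.Semaphore

// Hypothetical model of the two-stage acquisition (checkpoint, then
// real acquisition); numbers are illustrative.
val outputLong = 4
val supply = new Semaphore(0)

// Suppose endSupply released only outputLong permits on interruption:
supply.release(outputLong)

// Stage 1 -- the waitSupply "checkpoint": acquire, then give back.
supply.acquire(outputLong)
supply.release(outputLong)

// Stage 2 consumes the permits for real.
supply.acquire(outputLong)

// A further acquisition would now block: no permits remain, and the
// interrupted upstream will never release more.
assert(!supply.tryAcquire(outputLong))
```

The sketch runs to completion only because the last acquisition is a non-blocking tryAcquire; a real acquireN at that point would hang, which is exactly the stuck-stream symptom under discussion.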

Member

Thanks for that explanation!

Int.MaxValue is to account for the worst case scenario: at most the chunkSize parameter will be equal to Int.MaxValue

So could we just use chunkSize here, instead of Int.MaxValue?

Contributor Author

@armanbilge apologies I was wrong, that's not what's happening here. I'm just doing some tests to figure out why we need the additional outputLong

Member

Btw, if these implementation details are no longer relevant after your rewrite in the other PR, then let's not get too hung up on this one :)

Contributor Author (@Angel-O, Apr 2, 2023)

Ok, I think I've figured it out (might be useful for the other implementation, actually).

Basically, the problem is that we need enough supply to cover 2 iterations of the race loop. So if we only increase it by Int.MaxValue, the following happens:

  • (current iteration): supply is unblocked
  • (next iteration): supply gets stuck (not enough supply because upstream was interrupted)

If instead we increase it by Int.MaxValue + outputLong:

  • (current iteration): supply is unblocked
  • (next iteration): supply is not blocked, thanks to the additional outputLong

So since chunkSize can be as high as Int.MaxValue, the minimum supply guaranteed to unblock the semaphore is Int.MaxValue + outputLong.

Member

So since the chunkSize can be as high as Int.MaxValue then the minimum supply to unblock the semaphore should be Int.MaxValue + outputLong

Key word being "can". Wouldn't chunkSize + outputLong be sufficient?

Contributor Author

Yeah, that should work. The test still passes; I'll change it to outputLong * 2, since chunkSize == outputLong.
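To sanity-check the outputLong * 2 figure, here is a hypothetical script-style sketch of the permit arithmetic, again with java.util.concurrent.Semaphore standing in for cats-effect's and an illustrative value for outputLong. Each loop iteration consumes outputLong permits net (the checkpoint gives its permits back), so outputLong * 2 covers both the current and the next iteration:

```scala
import java.util.concurrent.Semaphore

// Hypothetical sketch (not the fs2 implementation): check that
// outputLong * 2 permits cover two iterations of the race loop.
val outputLong = 3
val supply = new Semaphore(0)

// endSupply releases outputLong * 2 on interruption:
supply.release(outputLong * 2)

// Iteration 1: checkpoint (acquire, then release back), followed by the
// real acquisition -- a net consumption of outputLong permits.
supply.acquire(outputLong)
supply.release(outputLong)
supply.acquire(outputLong)

// Iteration 2: the checkpoint can still pass thanks to the extra
// outputLong permits, so the stream is not stuck.
assert(supply.tryAcquire(outputLong))
```

With only outputLong released, the second iteration's checkpoint would block forever, which is the hang this PR fixes.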


   def endDemand(result: Either[Throwable, Unit]): F2[Unit] =
     buffer.update(_.copy(endOfDemand = Some(result))) *> demand.releaseN(Int.MaxValue)
core/shared/src/test/scala/fs2/StreamCombinatorsSuite.scala (68 additions, 1 deletion)
@@ -23,7 +23,7 @@ package fs2

import cats.effect.kernel.Deferred
import cats.effect.kernel.Ref
-import cats.effect.std.{Semaphore, Queue}
+import cats.effect.std.{Queue, Semaphore}
import cats.effect.testkit.TestControl
import cats.effect.{IO, SyncIO}
import cats.syntax.all._
@@ -34,6 +34,7 @@ import org.scalacheck.Prop.forAll

import scala.concurrent.duration._
import scala.concurrent.TimeoutException
import scala.util.control.NoStackTrace

class StreamCombinatorsSuite extends Fs2Suite {
override def munitIOTimeout = 1.minute
@@ -834,6 +835,72 @@ class StreamCombinatorsSuite extends Fs2Suite {
)
.assertEquals(0.millis)
}

  test("upstream failures are propagated downstream") {
    TestControl.executeEmbed {
      case object SevenNotAllowed extends NoStackTrace

      val source = Stream
        .iterate(0)(_ + 1)
        .covary[IO]
        .evalTap(n => IO.raiseError(SevenNotAllowed).whenA(n == 7))

      val downstream = source.groupWithin(100, 2.seconds)

      downstream.intercept[SevenNotAllowed.type]
Member

Should we make an assertion here about what the downstream has / has not received before the error?

    }
  }

  test(
    "upstream interruption causes immediate downstream termination with all elements being emitted"
  ) {

    val sourceTimeout = 5.5.seconds
    val downstreamTimeout = sourceTimeout + 2.seconds

    TestControl
      .executeEmbed {
        val source: Stream[IO, Int] =
          Stream
            .iterate(0)(_ + 1)
            .covary[IO]
            .meteredStartImmediately(1.second)
            .interruptAfter(sourceTimeout)

        // large chunkSize and timeout: no emissions are expected in the
        // specified window unless the source ends, due to interruption or
        // natural termination (i.e. it runs out of elements)
        val downstream: Stream[IO, Chunk[Int]] =
          source.groupWithin(Int.MaxValue, 1.day)

        downstream.compile.lastOrError
          .timeout(downstreamTimeout)
Member

Why is this timeout necessary? Since it's an executeEmbed test.

Contributor Author

It's because if the test fails we get a slightly better error message:

    java.util.concurrent.TimeoutException: 7500 milliseconds

which can easily be associated with downstreamTimeout. Otherwise we get this value in the diff, which looks a bit random:

    _1 = 86405500000000 nanoseconds,

but I'm happy to remove it.

Member

Got it, that is a nicer error :) thanks!

          .map(_.toList)
          .timed
      }
      .assertEquals(
        // downstream ended immediately (i.e. timeLapsed == sourceTimeout),
        // emitting whatever was accumulated at the time of interruption
        (sourceTimeout, List(0, 1, 2, 3, 4, 5))
      )
  }

  test("stress test: all elements are processed") {
    val rangeLength = 10000

    Stream
      .eval(Ref.of[IO, Int](0))
      .flatMap { counter =>
        Stream
          .range(0, rangeLength)
          .covary[IO]
          .groupWithin(4096, 100.micros)
          .evalTap(ch => counter.update(_ + ch.size)) *> Stream.eval(counter.get)
      }
      .compile
      .lastOrError
      .assertEquals(rangeLength)
  }
}

property("head")(forAll((s: Stream[Pure, Int]) => assertEquals(s.head.toList, s.toList.take(1))))