Skip to content

Commit

Permalink
[SPARK-50853] - Kind of ugly test case to show claims about not closi…
Browse files Browse the repository at this point in the history
…ng temp writable channel
  • Loading branch information
Michael Chen committed Jan 21, 2025
1 parent db5c0b1 commit a93d0f3
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,16 @@ public String path() {
return file.getAbsolutePath();
}

private class SimpleDownloadWritableChannel implements DownloadFileWritableChannel {
// public for testing
public class SimpleDownloadWritableChannel implements DownloadFileWritableChannel {

private final WritableByteChannel channel;

// for testing
public SimpleDownloadWritableChannel(WritableByteChannel channel) throws FileNotFoundException {
this.channel = channel;
}

SimpleDownloadWritableChannel() throws FileNotFoundException {
channel = Channels.newChannel(new FileOutputStream(file));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.network.netty

import java.io.IOException
import java.nio.channels.WritableByteChannel

import scala.concurrent.{ExecutionContext, Future}
import scala.reflect.ClassTag
Expand All @@ -31,8 +32,9 @@ import org.scalatest.matchers.should.Matchers._

import org.apache.spark.{ExecutorDeadException, SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.network.BlockDataManager
import org.apache.spark.network.client.{TransportClient, TransportClientFactory}
import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager}
import org.apache.spark.network.client.{RpcResponseCallback, StreamCallback, TransportClient, TransportClientFactory}
import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, SimpleDownloadFile}
import org.apache.spark.network.shuffle.protocol.StreamHandle
import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcTimeout}
import org.apache.spark.serializer.{JavaSerializer, SerializerManager}

Expand Down Expand Up @@ -130,6 +132,62 @@ class NettyBlockTransferServiceSuite
assert(hitExecutorDeadException)
}

test("SPARK-50853 - example of simple download file writable channel not being closed") {
implicit val executionContext = ExecutionContext.global
val port = 17634 + Random.nextInt(10000)
logInfo("random port for test: " + port)

val driverEndpointRef = new RpcEndpointRef(new SparkConf()) {
override def address: RpcAddress = null
override def name: String = "test"
override def send(message: Any): Unit = {}
override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
Future{false.asInstanceOf[T]}
}
}

val clientFactory = mock(classOf[TransportClientFactory])
val client = mock(classOf[TransportClient])
when(client.sendRpc(any(), any())).thenAnswer( invocation => {
invocation.getArgument[RpcResponseCallback](1).onSuccess(new StreamHandle(1, 1).toByteBuffer)
})
when(client.stream(any(), any())).thenAnswer(invocation => {
invocation.getArgument[StreamCallback](1).onComplete("1")
})
var createClientCount = 0
when(clientFactory.createClient(any(), any(), any())).thenAnswer(_ => {
createClientCount += 1
client
})

val listener = mock(classOf[BlockFetchingListener])
service0 = createService(port, driverEndpointRef)
val clientFactoryField = service0.getClass
.getSuperclass.getSuperclass.getDeclaredField("clientFactory")
clientFactoryField.setAccessible(true)
clientFactoryField.set(service0, clientFactory)

val downloadManager = mock(classOf[DownloadFileManager])
val simpleDownloadFile = mock(classOf[SimpleDownloadFile])
val channel = mock(classOf[WritableByteChannel])
var fileCreations = 0
var numChannelClosed = 0
when(downloadManager.createTempFile(any())).thenAnswer(_ => {
fileCreations += 1
simpleDownloadFile
})
when(simpleDownloadFile.openForWriting()).thenAnswer(_ => {
new simpleDownloadFile.SimpleDownloadWritableChannel(channel)
})
when(channel.close()).thenAnswer(_ => {
numChannelClosed += 1
})

service0.fetchBlocks("localhost", port, "exec1",
Array("block1"), listener, downloadManager)
assert(numChannelClosed == 1)
}

private def verifyServicePort(expectedPort: Int, actualPort: Int): Unit = {
actualPort should be >= expectedPort
// avoid testing equality in case of simultaneous tests
Expand Down

0 comments on commit a93d0f3

Please sign in to comment.