From 7d3e290f13c85223a4909f435777ff0d34eb9c91 Mon Sep 17 00:00:00 2001 From: Tang Haodong Date: Wed, 10 Apr 2019 11:17:29 +0800 Subject: [PATCH] rebase SSO to the newly updated HPNL Signed-off-by: Haodong Tang --- .../shuffle/sort/SerializedShuffleWriter.java | 6 +- .../apache/spark/network/pmof/Client.scala | 61 ++++++ .../spark/network/pmof/ClientFactory.scala | 122 ++++++++++++ ...ervice.scala => PmofTransferService.scala} | 51 ++--- .../spark/network/pmof/RdmaClient.scala | 178 ------------------ .../network/pmof/RdmaClientFactory.scala | 49 ----- .../spark/network/pmof/RdmaClientPool.scala | 35 ---- .../pmof/{RdmaServer.scala => Server.scala} | 82 ++++---- .../shuffle/pmof/BaseShuffleWriter.scala | 6 +- .../spark/shuffle/pmof/MetadataResolver.scala | 11 +- .../shuffle/pmof/PmemShuffleWriter.scala | 6 +- .../shuffle/pmof/PmofShuffleManager.scala | 7 +- .../shuffle/pmof/RdmaShuffleReader.scala | 14 +- .../pmof/PersistentMemoryHandler.scala | 4 +- ... => PmofShuffleBlockFetcherIterator.scala} | 17 +- 15 files changed, 270 insertions(+), 379 deletions(-) create mode 100644 src/main/scala/org/apache/spark/network/pmof/Client.scala create mode 100644 src/main/scala/org/apache/spark/network/pmof/ClientFactory.scala rename src/main/scala/org/apache/spark/network/pmof/{RdmaTransferService.scala => PmofTransferService.scala} (62%) delete mode 100644 src/main/scala/org/apache/spark/network/pmof/RdmaClient.scala delete mode 100644 src/main/scala/org/apache/spark/network/pmof/RdmaClientFactory.scala delete mode 100644 src/main/scala/org/apache/spark/network/pmof/RdmaClientPool.scala rename src/main/scala/org/apache/spark/network/pmof/{RdmaServer.scala => Server.scala} (65%) rename src/main/scala/org/apache/spark/storage/pmof/{RdmaShuffleBlockFetcherIterator.scala => PmofShuffleBlockFetcherIterator.scala} (96%) diff --git a/src/main/java/org/apache/spark/shuffle/sort/SerializedShuffleWriter.java b/src/main/java/org/apache/spark/shuffle/sort/SerializedShuffleWriter.java index 885b7292..b58d308d 100644 --- a/src/main/java/org/apache/spark/shuffle/sort/SerializedShuffleWriter.java +++ b/src/main/java/org/apache/spark/shuffle/sort/SerializedShuffleWriter.java @@ -25,7 +25,7 @@ import org.apache.spark.shuffle.pmof.MetadataResolver; import org.apache.spark.storage.BlockManagerId; import org.apache.spark.storage.BlockManagerId$; -import org.apache.spark.network.pmof.RdmaTransferService; +import org.apache.spark.network.pmof.PmofTransferService; import scala.Option; import scala.Product2; import scala.collection.JavaConverters; @@ -266,8 +266,8 @@ void closeAndWriteOutput() throws IOException { } BlockManagerId shuffleServerId = blockManager.shuffleServerId(); if (enable_rdma) { - BlockManagerId blockManagerId = BlockManagerId$.MODULE$.apply(shuffleServerId.executorId(), RdmaTransferService.shuffleNodesMap().get(shuffleServerId.host()).get(), - RdmaTransferService.getTransferServiceInstance(blockManager, null, false).port(), shuffleServerId.topologyInfo()); + BlockManagerId blockManagerId = BlockManagerId$.MODULE$.apply(shuffleServerId.executorId(), PmofTransferService.shuffleNodesMap().get(shuffleServerId.host()).get(), + PmofTransferService.getTransferServiceInstance(blockManager, null, false).port(), shuffleServerId.topologyInfo()); mapStatus = MapStatus$.MODULE$.apply(blockManagerId, partitionLengths); } else { mapStatus = MapStatus$.MODULE$.apply(shuffleServerId, partitionLengths); diff --git a/src/main/scala/org/apache/spark/network/pmof/Client.scala 
b/src/main/scala/org/apache/spark/network/pmof/Client.scala new file mode 100644 index 00000000..b78ccd71 --- /dev/null +++ b/src/main/scala/org/apache/spark/network/pmof/Client.scala @@ -0,0 +1,61 @@ +package org.apache.spark.network.pmof + +import java.nio.ByteBuffer +import java.util.concurrent.ConcurrentHashMap + +import com.intel.hpnl.core.{Connection, EqService} +import org.apache.spark.shuffle.pmof.PmofShuffleManager + +class Client(clientFactory: ClientFactory, val shuffleManager: PmofShuffleManager, con: Connection) { + final val outstandingReceiveFetches: ConcurrentHashMap[Long, ReceivedCallback] = + new ConcurrentHashMap[Long, ReceivedCallback]() + final val outstandingReadFetches: ConcurrentHashMap[Int, ReadCallback] = + new ConcurrentHashMap[Int, ReadCallback]() + final val shuffleBufferMap: ConcurrentHashMap[Int, ShuffleBuffer] = new ConcurrentHashMap[Int, ShuffleBuffer]() + + def getEqService: EqService = clientFactory.eqService + + def read(shuffleBuffer: ShuffleBuffer, reqSize: Int, + rmaAddress: Long, rmaRkey: Long, localAddress: Int, + callback: ReadCallback, isDeferred: Boolean = false): Unit = { + if (!isDeferred) { + outstandingReadFetches.putIfAbsent(shuffleBuffer.getRdmaBufferId, callback) + shuffleBufferMap.putIfAbsent(shuffleBuffer.getRdmaBufferId, shuffleBuffer) + } + val ret = con.read(shuffleBuffer.getRdmaBufferId, localAddress, reqSize, rmaAddress, rmaRkey) + if (ret == -11) { + if (isDeferred) { + clientFactory.deferredReadList.addFirst( + new ClientDeferredRead(this, shuffleBuffer, reqSize, rmaAddress, rmaRkey, localAddress) + ) + } else { + clientFactory.deferredReadList.addLast( + new ClientDeferredRead(this, shuffleBuffer, reqSize, rmaAddress, rmaRkey, localAddress) + ) + } + } + } + + def send(byteBuffer: ByteBuffer, seq: Long, msgType: Byte, + callback: ReceivedCallback, isDeferred: Boolean): Unit = { + assert(con != null) + if (callback != null) { + outstandingReceiveFetches.putIfAbsent(seq, callback) + } + val sendBuffer = this.con.takeSendBuffer(false) + if (sendBuffer == null) { + if (isDeferred) { + clientFactory.deferredSendList.addFirst( + new ClientDeferredSend(this, byteBuffer, seq, msgType, callback) + ) + } else { + clientFactory.deferredSendList.addLast( + new ClientDeferredSend(this, byteBuffer, seq, msgType, callback) + ) + } + return + } + sendBuffer.put(byteBuffer, msgType, seq) + con.send(sendBuffer.remaining(), sendBuffer.getBufferId) + } +} diff --git a/src/main/scala/org/apache/spark/network/pmof/ClientFactory.scala b/src/main/scala/org/apache/spark/network/pmof/ClientFactory.scala new file mode 100644 index 00000000..1c68f132 --- /dev/null +++ b/src/main/scala/org/apache/spark/network/pmof/ClientFactory.scala @@ -0,0 +1,122 @@ +package org.apache.spark.network.pmof + +import java.net.InetSocketAddress +import java.nio.ByteBuffer +import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingDeque} + +import com.intel.hpnl.core._ +import org.apache.spark.SparkConf +import org.apache.spark.shuffle.pmof.PmofShuffleManager + +import scala.collection.mutable.ArrayBuffer + +class ClientFactory(conf: SparkConf) { + final val SINGLE_BUFFER_SIZE: Int = PmofTransferService.CHUNKSIZE + final val BUFFER_NUM: Int = conf.getInt("spark.shuffle.pmof.client_buffer_nums", 16) + final val workers = conf.getInt("spark.shuffle.pmof.server_pool_size", 1) + + final val eqService = new EqService(workers, BUFFER_NUM, false).init() + final val cqService = new CqService(eqService).init() + + final val conArray: ArrayBuffer[Connection] = ArrayBuffer() + final 
val deferredSendList = new LinkedBlockingDeque[ClientDeferredSend]() + final val deferredReadList = new LinkedBlockingDeque[ClientDeferredRead]() + final val clientMap = new ConcurrentHashMap[InetSocketAddress, Client]() + final val conMap = new ConcurrentHashMap[Connection, Client]() + + def init(): Unit = { + eqService.initBufferPool(BUFFER_NUM, SINGLE_BUFFER_SIZE, BUFFER_NUM * 2) + cqService.addExternalEvent(new ExternalHandler { + override def handle(): Unit = { + handleDeferredSend() + handleDeferredRead() + } + }) + val clientRecvHandler = new ClientRecvHandler + val clientReadHandler = new ClientReadHandler + eqService.setRecvCallback(clientRecvHandler) + eqService.setReadCallback(clientReadHandler) + cqService.start() + } + + def createClient(shuffleManager: PmofShuffleManager, address: String, port: Int): Client = { + val socketAddress: InetSocketAddress = InetSocketAddress.createUnresolved(address, port) + var client = clientMap.get(socketAddress) + if (client == null) { + ClientFactory.this.synchronized { + client = clientMap.get(socketAddress) + if (client == null) { + val con = eqService.connect(address, port.toString, 0) + client = new Client(this, shuffleManager, con) + clientMap.put(socketAddress, client) + conMap.put(con, client) + } + } + } + client + } + + def stop(): Unit = { + cqService.shutdown() + } + + def waitToStop(): Unit = { + cqService.join() + eqService.shutdown() + eqService.join() + } + + def getEqService: EqService = eqService + + class ClientRecvHandler() extends Handler { + override def handle(con: Connection, rdmaBufferId: Int, blockBufferSize: Int): Unit = { + val buffer: HpnlBuffer = con.getRecvBuffer(rdmaBufferId) + val rpcMessage: ByteBuffer = buffer.get(blockBufferSize) + val seq = buffer.getSeq + val msgType = buffer.getType + val callback = conMap.get(con).outstandingReceiveFetches.get(seq) + if (msgType == 0.toByte) { + callback.onSuccess(null) + } else { + val metadataResolver = conMap.get(con).shuffleManager.metadataResolver + val blockInfoArray = metadataResolver.deserializeShuffleBlockInfo(rpcMessage) + callback.onSuccess(blockInfoArray) + } + } + } + + class ClientReadHandler() extends Handler { + override def handle(con: Connection, rdmaBufferId: Int, blockBufferSize: Int): Unit = { + def fun(v1: Int): Unit = { + conMap.get(con).shuffleBufferMap.remove(v1) + conMap.get(con).outstandingReadFetches.remove(v1) + } + + val callback = conMap.get(con).outstandingReadFetches.get(rdmaBufferId) + val shuffleBuffer = conMap.get(con).shuffleBufferMap.get(rdmaBufferId) + callback.onSuccess(shuffleBuffer, fun) + } + } + + def handleDeferredSend(): Unit = { + if (!deferredSendList.isEmpty) { + val deferredSend = deferredSendList.pollFirst() + deferredSend.client.send(deferredSend.byteBuffer, deferredSend.seq, + deferredSend.msgType, deferredSend.callback, isDeferred = true) + } + } + + def handleDeferredRead(): Unit = { + if (!deferredReadList.isEmpty) { + val deferredRead = deferredReadList.pollFirst() + deferredRead.client.read(deferredRead.shuffleBuffer, deferredRead.reqSize, + deferredRead.rmaAddress, deferredRead.rmaRkey, deferredRead.localAddress, null, isDeferred = true) + } + } +} + +class ClientDeferredSend(val client: Client, val byteBuffer: ByteBuffer, val seq: Long, val msgType: Byte, + val callback: ReceivedCallback) {} + +class ClientDeferredRead(val client: Client, val shuffleBuffer: ShuffleBuffer, val reqSize: Int, + val rmaAddress: Long, val rmaRkey: Long, val localAddress: Int) {} diff --git 
a/src/main/scala/org/apache/spark/network/pmof/RdmaTransferService.scala b/src/main/scala/org/apache/spark/network/pmof/PmofTransferService.scala similarity index 62% rename from src/main/scala/org/apache/spark/network/pmof/RdmaTransferService.scala rename to src/main/scala/org/apache/spark/network/pmof/PmofTransferService.scala index 11d9ac63..c158198e 100644 --- a/src/main/scala/org/apache/spark/network/pmof/RdmaTransferService.scala +++ b/src/main/scala/org/apache/spark/network/pmof/PmofTransferService.scala @@ -6,25 +6,19 @@ import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} import org.apache.spark.network.BlockDataManager import org.apache.spark.network.shuffle.{BlockFetchingListener, TempFileManager} -import org.apache.spark.serializer.JavaSerializer import org.apache.spark.shuffle.pmof.{MetadataResolver, PmofShuffleManager} import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockId} import org.apache.spark.{SparkConf, SparkEnv} import scala.collection.mutable -class RdmaTransferService(conf: SparkConf, val shuffleManager: PmofShuffleManager, val hostname: String, - var port: Int, val supportRma: Boolean) extends TransferService { - final var server: RdmaServer = _ - final private var recvHandler: ServerRecvHandler = _ - final private var connectHandler: ServerConnectHandler = _ - final private var clientFactory: RdmaClientFactory = _ - private var appId: String = _ +class PmofTransferService(conf: SparkConf, val shuffleManager: PmofShuffleManager, + val hostname: String, var port: Int) extends TransferService { + final var server: Server = _ + final private var clientFactory: ClientFactory = _ private var nextReqId: AtomicLong = _ final val metadataResolver: MetadataResolver = this.shuffleManager.metadataResolver - private val serializer = new JavaSerializer(conf) - override def fetchBlocks(host: String, port: Int, executId: String, @@ -34,8 +28,8 @@ class RdmaTransferService(conf: SparkConf, val shuffleManager: PmofShuffleManage def fetchBlock(reqHost: String, reqPort: Int, rmaAddress: Long, rmaLength: Int, rmaRkey: Long, localAddress: Int, shuffleBuffer: ShuffleBuffer, - rdmaClient: RdmaClient, callback: ReadCallback): Unit = { - rdmaClient.read(shuffleBuffer, rmaLength, rmaAddress, rmaRkey, localAddress, callback) + client: Client, callback: ReadCallback): Unit = { + client.read(shuffleBuffer, rmaLength, rmaAddress, rmaRkey, localAddress, callback) } def fetchBlockInfo(blockIds: Array[BlockId], receivedCallback: ReceivedCallback): Unit = { @@ -45,12 +39,12 @@ class RdmaTransferService(conf: SparkConf, val shuffleManager: PmofShuffleManage def syncBlocksInfo(host: String, port: Int, byteBuffer: ByteBuffer, msgType: Byte, callback: ReceivedCallback): Unit = { - clientFactory.createClient(shuffleManager, host, port, supportRma = false). + clientFactory.createClient(shuffleManager, host, port). 
send(byteBuffer, nextReqId.getAndIncrement(), msgType, callback, isDeferred = false) } - def getClient(reqHost: String, reqPort: Int): RdmaClient = { - clientFactory.createClient(shuffleManager, reqHost, reqPort, supportRma = true) + def getClient(reqHost: String, reqPort: Int): Client = { + clientFactory.createClient(shuffleManager, reqHost, reqPort) } override def close(): Unit = { @@ -65,15 +59,11 @@ class RdmaTransferService(conf: SparkConf, val shuffleManager: PmofShuffleManage } def init(): Unit = { - this.server = new RdmaServer(conf, shuffleManager, hostname, port, supportRma) - this.appId = conf.getAppId - this.recvHandler = new ServerRecvHandler(server, appId, serializer) - this.connectHandler = new ServerConnectHandler(server) - this.server.setRecvHandler(this.recvHandler) - this.server.setConnectHandler(this.connectHandler) - this.clientFactory = new RdmaClientFactory(conf) + this.server = new Server(conf, shuffleManager, hostname, port) + this.clientFactory = new ClientFactory(conf) this.server.init() this.server.start() + this.clientFactory.init() this.port = server.port val random = new Random().nextInt(Integer.MAX_VALUE) this.nextReqId = new AtomicLong(random) @@ -82,32 +72,31 @@ class RdmaTransferService(conf: SparkConf, val shuffleManager: PmofShuffleManage override def init(blockDataManager: BlockDataManager): Unit = {} } -object RdmaTransferService { +object PmofTransferService { final val env: SparkEnv = SparkEnv.get final val conf: SparkConf = env.conf final val CHUNKSIZE: Int = conf.getInt("spark.shuffle.pmof.chunk_size", 4096*3) final val driverHost: String = conf.get("spark.driver.rhost", defaultValue = "172.168.0.43") final val driverPort: Int = conf.getInt("spark.driver.rport", defaultValue = 61000) - val shuffleNodes: Array[Array[String]] = + final val shuffleNodes: Array[Array[String]] = conf.get("spark.shuffle.pmof.node", defaultValue = "").split(",").map(_.split("-")) - val shuffleNodesMap: mutable.Map[String, String] = new mutable.HashMap[String, String]() + final val shuffleNodesMap: mutable.Map[String, String] = new mutable.HashMap[String, String]() for (array <- shuffleNodes) { shuffleNodesMap.put(array(0), array(1)) } private val initialized = new AtomicBoolean(false) - private var transferService: RdmaTransferService = _ + private var transferService: PmofTransferService = _ def getTransferServiceInstance(blockManager: BlockManager, shuffleManager: PmofShuffleManager = null, - isDriver: Boolean = false): RdmaTransferService = { + isDriver: Boolean = false): PmofTransferService = { if (!initialized.get()) { - RdmaTransferService.this.synchronized { + PmofTransferService.this.synchronized { if (initialized.get()) return transferService if (isDriver) { transferService = - new RdmaTransferService(conf, shuffleManager, driverHost, driverPort, false) + new PmofTransferService(conf, shuffleManager, driverHost, driverPort) } else { transferService = - new RdmaTransferService(conf, shuffleManager, shuffleNodesMap(blockManager.shuffleServerId.host), - 0, false) + new PmofTransferService(conf, shuffleManager, shuffleNodesMap(blockManager.shuffleServerId.host), 0) } transferService.init() initialized.set(true) diff --git a/src/main/scala/org/apache/spark/network/pmof/RdmaClient.scala b/src/main/scala/org/apache/spark/network/pmof/RdmaClient.scala deleted file mode 100644 index a1808e4e..00000000 --- a/src/main/scala/org/apache/spark/network/pmof/RdmaClient.scala +++ /dev/null @@ -1,178 +0,0 @@ -package org.apache.spark.network.pmof - -import java.nio.ByteBuffer 
-import java.util.concurrent.atomic.AtomicBoolean -import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingDeque} - -import com.intel.hpnl.core._ -import org.apache.spark.SparkConf -import org.apache.spark.shuffle.pmof.PmofShuffleManager - -class RdmaClient(conf: SparkConf, val shuffleManager: PmofShuffleManager, address: String, port: Int, supportRma: Boolean) { - var SINGLE_BUFFER_SIZE: Int = _ - var BUFFER_NUM: Int = _ - if (supportRma) { - SINGLE_BUFFER_SIZE = 0 - BUFFER_NUM = 0 - } else { - SINGLE_BUFFER_SIZE = RdmaTransferService.CHUNKSIZE - BUFFER_NUM = conf.getInt("spark.shuffle.pmof.client_buffer_nums", 16) - } - final val eqService = new EqService(address, port.toString, 1, BUFFER_NUM, false).init() - final val cqService = new CqService(eqService, eqService.getNativeHandle).init() - - final val connectHandler = new ClientConnectHandler(this) - final val recvHandler = new ClientRecvHandler(this) - final val readHandler = new ClientReadHandler(this) - final val started: AtomicBoolean = new AtomicBoolean(false) - - val outstandingReceiveFetches: ConcurrentHashMap[Long, ReceivedCallback] = - new ConcurrentHashMap[Long, ReceivedCallback]() - val outstandingReadFetches: ConcurrentHashMap[Int, ReadCallback] = - new ConcurrentHashMap[Int, ReadCallback]() - - val shuffleBufferMap: ConcurrentHashMap[Int, ShuffleBuffer] = new ConcurrentHashMap[Int, ShuffleBuffer]() - - private var con: Connection = _ - - private val deferredReqList = new LinkedBlockingDeque[ClientDeferredReq]() - private val deferredReadList = new LinkedBlockingDeque[ClientDeferredRead]() - - def init(): Unit = { - eqService.initBufferPool(BUFFER_NUM, SINGLE_BUFFER_SIZE, BUFFER_NUM*2) - - cqService.addExternalEvent(new ExternalHandler { - override def handle(): Unit = { - handleDeferredReq() - handleDeferredRead() - } - }) - } - - def start(): Unit = { - eqService.setConnectedCallback(connectHandler) - eqService.setRecvCallback(recvHandler) - eqService.setReadCallback(readHandler) - - eqService.start() - cqService.start() - eqService.waitToConnected() - - started.set(true) - } - - def stop(): Unit = { - cqService.shutdown() - } - - def waitToStop(): Unit = { - cqService.join() - eqService.shutdown() - eqService.join() - } - - def setCon(con: Connection): Unit = { - this.con = con - } - - def getCon: Connection = { - assert(this.con != null) - this.con - } - - def handleDeferredReq(): Unit = { - if (!deferredReqList.isEmpty) { - val deferredReq = deferredReqList.pollFirst() - val byteBuffer = deferredReq.byteBuffer - val seq = deferredReq.seq - val msgType = deferredReq.msgType - val callback = deferredReq.callback - send(byteBuffer, seq, msgType, callback, isDeferred = true) - } - } - - def handleDeferredRead(): Unit = { - if (!deferredReadList.isEmpty) { - val deferredRead = deferredReadList.pollFirst() - read(deferredRead.shuffleBuffer, deferredRead.reqSize, deferredRead.rmaAddress, deferredRead.rmaRkey, deferredRead.localAddress, null, isDeferred = true) - } - } - - def read(shuffleBuffer: ShuffleBuffer, reqSize: Int, - rmaAddress: Long, rmaRkey: Long, localAddress: Int, - callback: ReadCallback, isDeferred: Boolean = false): Unit = { - if (!isDeferred) { - outstandingReadFetches.putIfAbsent(shuffleBuffer.getRdmaBufferId, callback) - shuffleBufferMap.putIfAbsent(shuffleBuffer.getRdmaBufferId, shuffleBuffer) - } - val ret = con.read(shuffleBuffer.getRdmaBufferId, localAddress, reqSize, rmaAddress, rmaRkey) - if (ret == -11) { - if (isDeferred) { - deferredReadList.addFirst(new ClientDeferredRead(shuffleBuffer, 
reqSize, rmaAddress, rmaRkey, localAddress)) - } else { - deferredReadList.addLast(new ClientDeferredRead(shuffleBuffer, reqSize, rmaAddress, rmaRkey, localAddress)) - } - } - } - - def send(byteBuffer: ByteBuffer, seq: Long, msgType: Byte, - callback: ReceivedCallback, isDeferred: Boolean): Unit = { - assert(con != null) - if (callback != null) { - outstandingReceiveFetches.putIfAbsent(seq, callback) - } - val sendBuffer = this.con.takeSendBuffer(false) - if (sendBuffer == null) { - if (isDeferred) { - deferredReqList.addFirst(new ClientDeferredReq(byteBuffer, seq, msgType, callback)) - } else { - deferredReqList.addLast(new ClientDeferredReq(byteBuffer, seq, msgType, callback)) - } - return - } - sendBuffer.put(byteBuffer, msgType, seq) - con.send(sendBuffer.remaining(), sendBuffer.getRdmaBufferId) - } - - def getEqService: EqService = eqService -} - -class ClientConnectHandler(client: RdmaClient) extends Handler { - override def handle(connection: Connection, rdmaBufferId: Int, bufferBufferSize: Int): Unit = { - client.setCon(connection) - } -} - -class ClientRecvHandler(client: RdmaClient) extends Handler { - override def handle(con: Connection, rdmaBufferId: Int, blockBufferSize: Int): Unit = { - val buffer: RdmaBuffer = con.getRecvBuffer(rdmaBufferId) - val rpcMessage: ByteBuffer = buffer.get(blockBufferSize) - val seq = buffer.getSeq - val msgType = buffer.getType - val callback = client.outstandingReceiveFetches.get(seq) - if (msgType == 0.toByte) { - callback.onSuccess(null) - } else { - val metadataResolver = client.shuffleManager.metadataResolver - val blockInfoArray = metadataResolver.deserializeShuffleBlockInfo(rpcMessage) - callback.onSuccess(blockInfoArray) - } - } -} - -class ClientReadHandler(client: RdmaClient) extends Handler { - def fun(v1: Int): Unit = { - client.shuffleBufferMap.remove(v1) - client.outstandingReadFetches.remove(v1) - } - override def handle(con: Connection, rdmaBufferId: Int, blockBufferSize: Int): Unit = { - val callback = client.outstandingReadFetches.get(rdmaBufferId) - val shuffleBuffer = client.shuffleBufferMap.get(rdmaBufferId) - callback.onSuccess(shuffleBuffer, fun) - } -} - -class ClientDeferredReq(val byteBuffer: ByteBuffer, val seq: Long, val msgType: Byte, - val callback: ReceivedCallback) {} - -class ClientDeferredRead(val shuffleBuffer: ShuffleBuffer, val reqSize: Int, val rmaAddress: Long, val rmaRkey: Long, val localAddress: Int) {} diff --git a/src/main/scala/org/apache/spark/network/pmof/RdmaClientFactory.scala b/src/main/scala/org/apache/spark/network/pmof/RdmaClientFactory.scala deleted file mode 100644 index aff739c4..00000000 --- a/src/main/scala/org/apache/spark/network/pmof/RdmaClientFactory.scala +++ /dev/null @@ -1,49 +0,0 @@ -package org.apache.spark.network.pmof - -import java.net.{InetSocketAddress, SocketAddress} -import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.atomic.AtomicInteger - -import org.apache.spark.SparkConf -import org.apache.spark.shuffle.pmof.PmofShuffleManager - -import scala.collection.JavaConverters._ - -class RdmaClientFactory(conf: SparkConf) { - final val CON_NUM: Int = conf.getInt("spark.shuffle.pmof.client_pool_size", 2) - val nextReqId: AtomicInteger = new AtomicInteger(0) - val conPools: ConcurrentHashMap[SocketAddress, RdmaClientPool] = - new ConcurrentHashMap[SocketAddress, RdmaClientPool]() - - def createClient(shuffleManager: PmofShuffleManager, address: String, port: Int, supportRma: Boolean): RdmaClient = { - val socketAddress: InetSocketAddress = 
InetSocketAddress.createUnresolved(address, port) - var conPool: RdmaClientPool = conPools.get(socketAddress) - if (conPool == null) { - RdmaClientFactory.this.synchronized { - conPool = conPools.get(socketAddress) - if (conPool == null) { - conPool = new RdmaClientPool(conf, shuffleManager, CON_NUM, address, port) - conPools.put(socketAddress, conPool) - } - } - } - if (!supportRma) { - conPool.get(0) - } else { - val reqId = nextReqId.getAndIncrement() - conPool.get(reqId%(CON_NUM-1)+1) - } - } - - def stop(): Unit = { - for ((_, v) <- conPools.asScala) { - v.stop() - } - } - - def waitToStop(): Unit = { - for ((_, v) <- conPools.asScala) { - v.waitToStop() - } - } -} diff --git a/src/main/scala/org/apache/spark/network/pmof/RdmaClientPool.scala b/src/main/scala/org/apache/spark/network/pmof/RdmaClientPool.scala deleted file mode 100644 index 3414eef7..00000000 --- a/src/main/scala/org/apache/spark/network/pmof/RdmaClientPool.scala +++ /dev/null @@ -1,35 +0,0 @@ -package org.apache.spark.network.pmof - -import org.apache.spark.SparkConf -import org.apache.spark.shuffle.pmof.PmofShuffleManager - -class RdmaClientPool(conf: SparkConf, shuffleManager: PmofShuffleManager, poolSize: Int, address: String, port: Int) { - val RdmaClients = new Array[RdmaClient](poolSize) - - def get(index: Int): RdmaClient = { - if (RdmaClients(index) == null || !RdmaClients(index).started.get()) { - RdmaClientPool.this.synchronized { - if (RdmaClients(index) == null || !RdmaClients(index).started.get()) { - if (index == 0) { - RdmaClients(0) = new RdmaClient(conf, shuffleManager, address, port, false) - RdmaClients(0).init() - RdmaClients(0).start() - } else { - RdmaClients(index) = new RdmaClient(conf, shuffleManager, address, port, true) - RdmaClients(index).init() - RdmaClients(index).start() - } - } - } - } - RdmaClients(index) - } - - def stop(): Unit = { - RdmaClients.foreach(_.stop()) - } - - def waitToStop(): Unit = { - RdmaClients.foreach(_.waitToStop()) - } -} diff --git a/src/main/scala/org/apache/spark/network/pmof/RdmaServer.scala b/src/main/scala/org/apache/spark/network/pmof/Server.scala similarity index 65% rename from src/main/scala/org/apache/spark/network/pmof/RdmaServer.scala rename to src/main/scala/org/apache/spark/network/pmof/Server.scala index 149db219..9008fa25 100644 --- a/src/main/scala/org/apache/spark/network/pmof/RdmaServer.scala +++ b/src/main/scala/org/apache/spark/network/pmof/Server.scala @@ -4,41 +4,40 @@ import java.nio.ByteBuffer import java.util import java.util.concurrent.LinkedBlockingDeque -import org.apache.spark.serializer.Serializer import com.intel.hpnl.core._ import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.shuffle.pmof.PmofShuffleManager -class RdmaServer(conf: SparkConf, val shuffleManager: PmofShuffleManager, address: String, var port: Int, supportRma: Boolean) { +class Server(conf: SparkConf, val shuffleManager: PmofShuffleManager, address: String, var port: Int) { if (port == 0) { port = Utils.getPort } - var SINGLE_BUFFER_SIZE: Int = _ - var BUFFER_NUM: Int = _ - - if (supportRma) { - SINGLE_BUFFER_SIZE = 0 - BUFFER_NUM = 0 - } else { - SINGLE_BUFFER_SIZE = RdmaTransferService.CHUNKSIZE - BUFFER_NUM = conf.getInt("spark.shuffle.pmof.server_buffer_nums", 256) - } - + final val SINGLE_BUFFER_SIZE: Int = PmofTransferService.CHUNKSIZE + final val BUFFER_NUM: Int = conf.getInt("spark.shuffle.pmof.server_buffer_nums", 256) final val workers = conf.getInt("spark.shuffle.pmof.server_pool_size", 1) - final val 
eqService = new EqService(address, port.toString, workers, BUFFER_NUM, true).init() - final val cqService = new CqService(eqService, eqService.getNativeHandle).init() + final val eqService = new EqService(workers, BUFFER_NUM, true).init() + final val cqService = new CqService(eqService).init() val conList = new util.ArrayList[Connection]() def init(): Unit = { - eqService.initBufferPool(BUFFER_NUM, SINGLE_BUFFER_SIZE, BUFFER_NUM*2) + eqService.initBufferPool(BUFFER_NUM, SINGLE_BUFFER_SIZE, BUFFER_NUM * 2) + val recvHandler = new ServerRecvHandler(this) + val connectedHandler = new ServerConnectedHandler(this) + cqService.addExternalEvent(new ExternalHandler { + override def handle(): Unit = { + recvHandler.handleDeferredSend() + } + }) + eqService.setConnectedCallback(connectedHandler) + eqService.setRecvCallback(recvHandler) } def start(): Unit = { - eqService.start() cqService.start() + eqService.listen(address, port.toString) } def stop(): Unit = { @@ -51,19 +50,6 @@ class RdmaServer(conf: SparkConf, val shuffleManager: PmofShuffleManager, addres eqService.join() } - def setRecvHandler(handler: Handler): Unit = { - eqService.setRecvCallback(handler) - cqService.addExternalEvent(new ExternalHandler { - override def handle(): Unit = { - handler.asInstanceOf[ServerRecvHandler] handleDeferredReq() - } - }) - } - - def setConnectHandler(handler: Handler): Unit = { - eqService.setConnectedCallback(handler) - } - def getEqService: EqService = { eqService } @@ -71,16 +57,12 @@ class RdmaServer(conf: SparkConf, val shuffleManager: PmofShuffleManager, addres def addCon(con: Connection): Unit = synchronized { conList.add(con) } - - def getConSize: Int = synchronized { - conList.size() - } } -class ServerRecvHandler(server: RdmaServer, appid: String, serializer: Serializer) extends Handler with Logging { +class ServerRecvHandler(server: Server) extends Handler with Logging { - private final val deferredBufferList = new LinkedBlockingDeque[ServerDeferredReq]() private final val byteBufferTmp = ByteBuffer.allocate(4) + private final val deferredBufferList = new LinkedBlockingDeque[ServerDeferredReq]() def sendMetadata(con: Connection, byteBuffer: ByteBuffer, msgType: Byte, seq: Long, isDeferred: Boolean): Unit = { val sendBuffer = con.takeSendBuffer(false) @@ -93,21 +75,11 @@ class ServerRecvHandler(server: RdmaServer, appid: String, serializer: Serialize return } sendBuffer.put(byteBuffer, msgType, seq) - con.send(sendBuffer.remaining(), sendBuffer.getRdmaBufferId) + con.send(sendBuffer.remaining(), sendBuffer.getBufferId) } - def handleDeferredReq(): Unit = { - val deferredReq = deferredBufferList.pollFirst - if (deferredReq == null) return - val con = deferredReq.con - val byteBuffer = deferredReq.byteBuffer - val msgType = deferredReq.msgType - val seq = deferredReq.seq - sendMetadata(con, byteBuffer, msgType, seq, isDeferred = true) - } - - override def handle(con: Connection, rdmaBufferId: Int, blockBufferSize: Int): Unit = synchronized { - val buffer: RdmaBuffer = con.getRecvBuffer(rdmaBufferId) + override def handle(con: Connection, bufferId: Int, blockBufferSize: Int): Unit = synchronized { + val buffer: HpnlBuffer = con.getRecvBuffer(bufferId) val message: ByteBuffer = buffer.get(blockBufferSize) val seq = buffer.getSeq val msgType = buffer.getType @@ -120,9 +92,19 @@ class ServerRecvHandler(server: RdmaServer, appid: String, serializer: Serialize sendMetadata(con, outputBuffer, 1.toByte, seq, isDeferred = false) } } + + def handleDeferredSend(): Unit = { + val deferredReq = 
deferredBufferList.pollFirst + if (deferredReq == null) return + val con = deferredReq.con + val byteBuffer = deferredReq.byteBuffer + val msgType = deferredReq.msgType + val seq = deferredReq.seq + sendMetadata(con, byteBuffer, msgType, seq, isDeferred = true) + } } -class ServerConnectHandler(server: RdmaServer) extends Handler { +class ServerConnectedHandler(server: Server) extends Handler { override def handle(con: Connection, rdmaBufferId: Int, bufferBufferSize: Int): Unit = { server.addCon(con) } diff --git a/src/main/scala/org/apache/spark/shuffle/pmof/BaseShuffleWriter.scala b/src/main/scala/org/apache/spark/shuffle/pmof/BaseShuffleWriter.scala index a180de43..bab5e3a1 100644 --- a/src/main/scala/org/apache/spark/shuffle/pmof/BaseShuffleWriter.scala +++ b/src/main/scala/org/apache/spark/shuffle/pmof/BaseShuffleWriter.scala @@ -19,7 +19,7 @@ package org.apache.spark.shuffle.pmof import org.apache.spark._ import org.apache.spark.internal.Logging -import org.apache.spark.network.pmof.RdmaTransferService +import org.apache.spark.network.pmof.PmofTransferService import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter} import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} @@ -80,8 +80,8 @@ private[spark] class BaseShuffleWriter[K, V, C]( val shuffleServerId = blockManager.shuffleServerId if (enable_rdma) { val blockManagerId: BlockManagerId = - BlockManagerId(shuffleServerId.executorId, RdmaTransferService.shuffleNodesMap(shuffleServerId.host), - RdmaTransferService.getTransferServiceInstance(blockManager).port, shuffleServerId.topologyInfo) + BlockManagerId(shuffleServerId.executorId, PmofTransferService.shuffleNodesMap(shuffleServerId.host), + PmofTransferService.getTransferServiceInstance(blockManager).port, shuffleServerId.topologyInfo) mapStatus = MapStatus(blockManagerId, partitionLengths) } else { mapStatus = MapStatus(shuffleServerId, partitionLengths) diff --git a/src/main/scala/org/apache/spark/shuffle/pmof/MetadataResolver.scala b/src/main/scala/org/apache/spark/shuffle/pmof/MetadataResolver.scala index 161291d7..5029d320 100644 --- a/src/main/scala/org/apache/spark/shuffle/pmof/MetadataResolver.scala +++ b/src/main/scala/org/apache/spark/shuffle/pmof/MetadataResolver.scala @@ -15,7 +15,6 @@ import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId} import scala.collection.JavaConverters._ -import scala.collection.mutable import scala.collection.mutable.ArrayBuffer class MetadataResolver(conf: SparkConf) { @@ -86,7 +85,7 @@ class MetadataResolver(conf: SparkConf) { } } - RdmaTransferService.getTransferServiceInstance(null, null). + PmofTransferService.getTransferServiceInstance(null, null). 
syncBlocksInfo(driverHost, driverPort, byteBuffer, 0.toByte, receivedCallback) latch.await() @@ -103,11 +102,11 @@ class MetadataResolver(conf: SparkConf) { totalLength = totalLength + currentLength } - val eqService = RdmaTransferService.getTransferServiceInstance(blockManager).server.getEqService + val eqService = PmofTransferService.getTransferServiceInstance(blockManager).server.getEqService val shuffleBuffer = new ShuffleBuffer(0, totalLength, channel, eqService) val startedAddress = shuffleBuffer.getAddress val rdmaBuffer = eqService.regRmaBufferByAddress(shuffleBuffer.nioByteBuffer(), startedAddress, totalLength.toInt) - shuffleBuffer.setRdmaBufferId(rdmaBuffer.getRdmaBufferId) + shuffleBuffer.setRdmaBufferId(rdmaBuffer.getBufferId) shuffleBuffer.setRkey(rdmaBuffer.getRKey) val blockId = ShuffleBlockId(shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) blockMap.put(blockId.name, shuffleBuffer) @@ -154,7 +153,7 @@ class MetadataResolver(conf: SparkConf) { } } - RdmaTransferService.getTransferServiceInstance(null, null). + PmofTransferService.getTransferServiceInstance(null, null). syncBlocksInfo(driverHost, driverPort, byteBuffer, 0.toByte, receivedCallback) latch.await() } @@ -175,7 +174,7 @@ class MetadataResolver(conf: SparkConf) { byteBufferTmp.putInt(blockIds(i).reduceId) } byteBufferTmp.flip() - RdmaTransferService.getTransferServiceInstance(null, null). + PmofTransferService.getTransferServiceInstance(null, null). syncBlocksInfo(driverHost, driverPort, byteBufferTmp, 1.toByte, receivedCallback) } diff --git a/src/main/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriter.scala b/src/main/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriter.scala index 6d30b2b3..9cc43275 100644 --- a/src/main/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriter.scala +++ b/src/main/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriter.scala @@ -19,7 +19,7 @@ package org.apache.spark.shuffle.pmof import org.apache.spark._ import org.apache.spark.internal.Logging -import org.apache.spark.network.pmof.RdmaTransferService +import org.apache.spark.network.pmof.PmofTransferService import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter} import org.apache.spark.shuffle.pmof._ @@ -129,8 +129,8 @@ private[spark] class PmemShuffleWriter[K, V, C]( val rkey = partitionBufferArray(0).getRkey() metadataResolver.commitPmemBlockInfo(stageId, mapId, data_addr_map, rkey) val blockManagerId: BlockManagerId = - BlockManagerId(shuffleServerId.executorId, RdmaTransferService.shuffleNodesMap(shuffleServerId.host), - RdmaTransferService.getTransferServiceInstance(blockManager).port, shuffleServerId.topologyInfo) + BlockManagerId(shuffleServerId.executorId, PmofTransferService.shuffleNodesMap(shuffleServerId.host), + PmofTransferService.getTransferServiceInstance(blockManager).port, shuffleServerId.topologyInfo) mapStatus = MapStatus(blockManagerId, partitionLengths) } else { mapStatus = MapStatus(shuffleServerId, partitionLengths) diff --git a/src/main/scala/org/apache/spark/shuffle/pmof/PmofShuffleManager.scala b/src/main/scala/org/apache/spark/shuffle/pmof/PmofShuffleManager.scala index 979f7b35..3bcc44f4 100644 --- a/src/main/scala/org/apache/spark/shuffle/pmof/PmofShuffleManager.scala +++ b/src/main/scala/org/apache/spark/shuffle/pmof/PmofShuffleManager.scala @@ -3,9 +3,8 @@ package org.apache.spark.shuffle.pmof import java.util.concurrent.ConcurrentHashMap import org.apache.spark.internal.Logging -import 
org.apache.spark.network.pmof.RdmaTransferService +import org.apache.spark.network.pmof.PmofTransferService import org.apache.spark.shuffle._ -import org.apache.spark.shuffle.pmof._ import org.apache.spark.shuffle.sort.{SerializedShuffleHandle, SerializedShuffleWriter, SortShuffleManager} import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext} @@ -26,7 +25,7 @@ private[spark] class PmofShuffleManager(conf: SparkConf) extends ShuffleManager override def registerShuffle[K, V, C](shuffleId: Int, numMaps: Int, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { val env: SparkEnv = SparkEnv.get if (enable_rdma) { - RdmaTransferService.getTransferServiceInstance(env.blockManager, this, isDriver = true) + PmofTransferService.getTransferServiceInstance(env.blockManager, this, isDriver = true) } if (enable_pmem) { new BaseShuffleHandle(shuffleId, numMaps, dependency) @@ -45,7 +44,7 @@ private[spark] class PmofShuffleManager(conf: SparkConf) extends ShuffleManager val numMaps = handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps numMapsForShuffle.putIfAbsent(handle.shuffleId, numMaps) if (enable_rdma) { - RdmaTransferService.getTransferServiceInstance(env.blockManager, this) + PmofTransferService.getTransferServiceInstance(env.blockManager, this) } handle match { case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => diff --git a/src/main/scala/org/apache/spark/shuffle/pmof/RdmaShuffleReader.scala b/src/main/scala/org/apache/spark/shuffle/pmof/RdmaShuffleReader.scala index 9fae22fd..653b9679 100644 --- a/src/main/scala/org/apache/spark/shuffle/pmof/RdmaShuffleReader.scala +++ b/src/main/scala/org/apache/spark/shuffle/pmof/RdmaShuffleReader.scala @@ -1,13 +1,11 @@ package org.apache.spark.shuffle.pmof -import java.util.UUID - import org.apache.spark._ import org.apache.spark.internal.{Logging, config} -import org.apache.spark.network.pmof.RdmaTransferService -import org.apache.spark.serializer.SerializerManager +import org.apache.spark.network.pmof.PmofTransferService +import org.apache.spark.serializer.{SerializerInstance, SerializerManager} import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} -import org.apache.spark.storage.{BlockManager, TempShuffleBlockId, BlockId} +import org.apache.spark.storage.BlockManager import org.apache.spark.storage.pmof._ import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter @@ -28,14 +26,14 @@ private[spark] class RdmaShuffleReader[K, C]( extends ShuffleReader[K, C] with Logging { private val dep = handle.dependency - val serializerInstance = dep.serializer.newInstance() + val serializerInstance: SerializerInstance = dep.serializer.newInstance() val enable_pmem: Boolean = SparkEnv.get.conf.getBoolean("spark.shuffle.pmof.enable_pmem", defaultValue = true) /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { val wrappedStreams: RdmaShuffleBlockFetcherIterator = new RdmaShuffleBlockFetcherIterator( context, - RdmaTransferService.getTransferServiceInstance(blockManager), + PmofTransferService.getTransferServiceInstance(blockManager), blockManager, mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition), serializerManager.wrapStream, @@ -87,7 +85,7 @@ private[spark] class RdmaShuffleReader[K, C]( // Sort the output if there is a sort ordering defined. 
dep.keyOrdering match { case Some(keyOrd: Ordering[K]) => - if (enable_pmem == true) { + if (enable_pmem) { val sorter = new PmemExternalSorter[K, C, C](context, handle, ordering = Some(keyOrd), serializer = dep.serializer) sorter.insertAll(aggregatedIter) CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) diff --git a/src/main/scala/org/apache/spark/storage/pmof/PersistentMemoryHandler.scala b/src/main/scala/org/apache/spark/storage/pmof/PersistentMemoryHandler.scala index e3fb334b..5b25a3d4 100644 --- a/src/main/scala/org/apache/spark/storage/pmof/PersistentMemoryHandler.scala +++ b/src/main/scala/org/apache/spark/storage/pmof/PersistentMemoryHandler.scala @@ -19,7 +19,7 @@ package org.apache.spark.storage.pmof import org.apache.spark.internal.Logging -import org.apache.spark.network.pmof.RdmaTransferService +import org.apache.spark.network.pmof.PmofTransferService import org.apache.spark.SparkEnv import scala.collection.JavaConverters._ @@ -156,7 +156,7 @@ object PersistentMemoryHandler { persistentMemoryHandler.log("Use persistentMemoryHandler Object: " + this) if (enable_rdma) { val blockManager = SparkEnv.get.blockManager - val eqService = RdmaTransferService.getTransferServiceInstance(blockManager).server.getEqService + val eqService = PmofTransferService.getTransferServiceInstance(blockManager).server.getEqService val size: Long = 264239054848L val offset: Long = persistentMemoryHandler.getRootAddr val rdmaBuffer = eqService.regRmaBufferByAddress(null, offset, size) diff --git a/src/main/scala/org/apache/spark/storage/pmof/RdmaShuffleBlockFetcherIterator.scala b/src/main/scala/org/apache/spark/storage/pmof/PmofShuffleBlockFetcherIterator.scala similarity index 96% rename from src/main/scala/org/apache/spark/storage/pmof/RdmaShuffleBlockFetcherIterator.scala rename to src/main/scala/org/apache/spark/storage/pmof/PmofShuffleBlockFetcherIterator.scala index 1b3e7f32..5f113d57 100644 --- a/src/main/scala/org/apache/spark/storage/pmof/RdmaShuffleBlockFetcherIterator.scala +++ b/src/main/scala/org/apache/spark/storage/pmof/PmofShuffleBlockFetcherIterator.scala @@ -162,7 +162,7 @@ final class RdmaShuffleBlockFetcherIterator( numBlocksToFetch += blockIds.length - val rdmaTransferService = shuffleClient.asInstanceOf[RdmaTransferService] + val rdmaTransferService = shuffleClient.asInstanceOf[PmofTransferService] rdmaTransferService.fetchBlockInfo(blockIds, receivedCallback) } @@ -177,7 +177,7 @@ final class RdmaShuffleBlockFetcherIterator( numBlocksInFlightPerAddress(blockManagerId) = numBlocksInFlightPerAddress.getOrElse(blockManagerId, 0) + 1 - val rdmaTransferService = shuffleClient.asInstanceOf[RdmaTransferService] + val pmofTransferService = shuffleClient.asInstanceOf[PmofTransferService] val blockFetchingReadCallback = new ReadCallback { def onSuccess(shuffleBuffer: ShuffleBuffer, f: Int => Unit): Unit = { @@ -197,14 +197,17 @@ final class RdmaShuffleBlockFetcherIterator( } } - val rdmaClient = rdmaTransferService.getClient(blockManagerId.host, blockManagerId.port) - val shuffleBuffer = new ShuffleBuffer(rdmaRequest.reqSize, rdmaClient.getEqService, true) - val rdmaBuffer = rdmaClient.getEqService.regRmaBufferByAddress(shuffleBuffer.nioByteBuffer(), shuffleBuffer.getAddress, shuffleBuffer.getLength.toInt) - shuffleBuffer.setRdmaBufferId(rdmaBuffer.getRdmaBufferId) + val client = pmofTransferService.getClient(blockManagerId.host, blockManagerId.port) + val shuffleBuffer = new ShuffleBuffer(rdmaRequest.reqSize, client.getEqService, true) + val 
rdmaBuffer = client.getEqService.regRmaBufferByAddress(shuffleBuffer.nioByteBuffer(), + shuffleBuffer.getAddress, shuffleBuffer.getLength.toInt) + shuffleBuffer.setRdmaBufferId(rdmaBuffer.getBufferId) var offset = 0 for (i <- 0 until partitionNums) { - rdmaTransferService.fetchBlock(blockManagerId.host, blockManagerId.port, shuffleBlockInfos(i).getAddress, shuffleBlockInfos(i).getLength, shuffleBlockInfos(i).getRkey, offset, shuffleBuffer, rdmaClient, blockFetchingReadCallback) + pmofTransferService.fetchBlock(blockManagerId.host, blockManagerId.port, + shuffleBlockInfos(i).getAddress, shuffleBlockInfos(i).getLength, + shuffleBlockInfos(i).getRkey, offset, shuffleBuffer, client, blockFetchingReadCallback) offset += shuffleBlockInfos(i).getLength } }
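
Note for reviewers: the hunks above move from HPNL's old endpoint-bound API (address and port passed to the EqService constructor, one RdmaClientPool per remote endpoint) to the updated one, where a single EqService/CqService pair is initialized first and endpoints are bound afterwards, via eqService.connect(...) on the client side and eqService.listen(...) on the server side. The minimal, self-contained sketch below strings those calls together in the same order this patch uses them (callbacks registered before cqService.start(), listen/connect issued last). The host, port, buffer counts, and the empty receive-handler body are illustrative placeholders, not part of the patch.

import java.nio.ByteBuffer

import com.intel.hpnl.core.{Connection, CqService, EqService, Handler, HpnlBuffer}

object HpnlBootstrapSketch {
  def main(args: Array[String]): Unit = {
    val workers = 1          // event-queue worker threads (placeholder)
    val bufferNum = 16       // matches the spark.shuffle.pmof.client_buffer_nums default
    val chunkSize = 4096 * 3 // matches the spark.shuffle.pmof.chunk_size default

    // Server side: construct services, register callbacks, then listen.
    val serverEq = new EqService(workers, bufferNum, true).init()
    val serverCq = new CqService(serverEq).init()
    serverEq.initBufferPool(bufferNum, chunkSize, bufferNum * 2)
    serverEq.setRecvCallback(new Handler {
      override def handle(con: Connection, bufferId: Int, bufferSize: Int): Unit = {
        val buffer: HpnlBuffer = con.getRecvBuffer(bufferId) // message handling elided
      }
    })
    serverCq.start()
    serverEq.listen("127.0.0.1", "61000") // placeholder address/port

    // Client side: construct services, start completion handling, then connect.
    // connect(...) returns the Connection directly under the new API.
    val clientEq = new EqService(workers, bufferNum, false).init()
    val clientCq = new CqService(clientEq).init()
    clientEq.initBufferPool(bufferNum, chunkSize, bufferNum * 2)
    clientCq.start()
    val con: Connection = clientEq.connect("127.0.0.1", "61000", 0)

    // Sends go through pre-registered buffers; when takeSendBuffer returns
    // null, the patch defers the request to a deque instead of blocking.
    val sendBuffer = con.takeSendBuffer(false)
    if (sendBuffer != null) {
      sendBuffer.put(ByteBuffer.allocate(0), 0.toByte, 0L) // empty payload, illustrative
      con.send(sendBuffer.remaining(), sendBuffer.getBufferId)
    }
  }
}

This single-EqService design is why RdmaClientPool and the per-client connect handshake could be deleted: ClientFactory now multiplexes all remote endpoints over one event queue and resolves each Connection back to its Client through the conMap lookup shown above.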