diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
index 5b34104a3f29..07d55dec1e03 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
@@ -23,7 +23,6 @@ import org.apache.gluten.runtime.Runtimes
 import org.apache.gluten.sql.shims.SparkShimLoader
 import org.apache.gluten.utils.ArrowAbiUtil
 import org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper, NativeColumnarToRowJniWrapper}
-
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, Expression, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.execution.joins.BuildSideRelation
@@ -31,12 +30,13 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.utils.SparkArrowUtil
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.task.TaskResources
-
 import org.apache.arrow.c.ArrowSchema
+import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector
+import org.apache.spark.sql.types.DataTypes
 
 import scala.collection.JavaConverters.asScalaIteratorConverter
 
-case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: Array[Array[Byte]])
+case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: UnsafeArray)
   extends BuildSideRelation {
 
   override def deserialized: Iterator[ColumnarBatch] = {
@@ -60,15 +60,16 @@ case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: Array[Arra
           var batchId = 0
 
           override def hasNext: Boolean = {
-            batchId < batches.length
+            batchId < batches.getLength
           }
 
           override def next: ColumnarBatch = {
-            val handle =
-              jniWrapper
-                .deserialize(serializeHandle, batches(batchId))
+            val batch = batches.get(batchId)
+            val columnVector = new OffHeapColumnVector(batch.numElements(), DataTypes.BinaryType)
+            columnVector.putByteArray(batchId, batch.toByteArray, batch.getBaseOffset.toInt, batch.numElements)
+            val columnarBatch = new ColumnarBatch(Array(columnVector))
             batchId += 1
-            ColumnarBatches.create(handle)
+            columnarBatch
           }
         })
         .protectInvocationFlow()
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/UnsafeArray.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/UnsafeArray.scala
new file mode 100644
index 000000000000..940bb1a34c90
--- /dev/null
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/UnsafeArray.scala
@@ -0,0 +1,73 @@
+package org.apache.spark.sql.execution
+
+import org.apache.spark.memory.{MemoryConsumer, MemoryMode, SparkOutOfMemoryError, TaskMemoryManager}
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
+import org.apache.spark.unsafe.{Platform, UnsafeAlignedOffset}
+import org.apache.spark.unsafe.memory.MemoryBlock
+
+/**
+ * An append-only array of variable-length byte records stored in a single off-heap page
+ * acquired from the task memory manager. Records are appended with `write` and read back
+ * as [[UnsafeArrayData]] with `get`.
+ */
+class UnsafeArray(taskMemoryManager: TaskMemoryManager)
+  extends MemoryConsumer(taskMemoryManager, MemoryMode.OFF_HEAP) {
+
+  protected var page: MemoryBlock = null
+  acquirePage(taskMemoryManager.pageSizeBytes)
+  protected var base: AnyRef = page.getBaseObject
+  protected var pageCursor: Long = 0
+  // Offset of each stored record within the page, indexed by row id.
+  private var keyOffsets: Array[Long] = new Array[Long](1024)
+  protected var numRows = 0
+
+  def getLength: Int = numRows
+
+  def iterator(): Unit = {}
+
+  private def acquirePage(requiredSize: Long): Boolean = {
+    try {
+      page = allocatePage(requiredSize)
+    } catch {
+      case _: SparkOutOfMemoryError =>
+        return false
+    }
+    base = page.getBaseObject
+    pageCursor = 0
+    true
+  }
+
+  def get(rowId: Int): UnsafeArrayData = {
+    val offset = keyOffsets(rowId)
+    // The record length is stored immediately before the record bytes.
+    val klen = UnsafeAlignedOffset.getSize(base, offset - UnsafeAlignedOffset.getUaoSize)
+    val result = new UnsafeArrayData
+    result.pointTo(base, offset, klen)
+    result
+  }
+
+  def write(bytes: Array[Byte], inputOffset: Long, inputLength: Int): Unit = {
+    var offset: Long = page.getBaseOffset + pageCursor
+    val recordOffset = offset
+
+    val uaoSize = UnsafeAlignedOffset.getUaoSize
+    // Record layout: [total length][record length][record bytes][8-byte trailer]
+    val recordLength = 2L * uaoSize + inputLength + 8L
+
+    UnsafeAlignedOffset.putSize(base, offset, inputLength + uaoSize)
+    UnsafeAlignedOffset.putSize(base, offset + uaoSize, inputLength)
+    offset += 2L * uaoSize
+    Platform.copyMemory(bytes, inputOffset, base, offset, inputLength)
+    Platform.putLong(base, offset + inputLength, 0)
+
+    pageCursor += recordLength
+    if (numRows == keyOffsets.length) {
+      keyOffsets = java.util.Arrays.copyOf(keyOffsets, keyOffsets.length * 2)
+    }
+    keyOffsets(numRows) = recordOffset + 2L * uaoSize
+    numRows += 1
+  }
+
+  // Spilling is not implemented yet.
+  override def spill(size: Long, trigger: MemoryConsumer): Long = ???
+}
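
For reference, a minimal usage sketch (not part of the patch) of how the new UnsafeArray could be filled and read back. Only write, get and getLength come from the class above; the serialized byte arrays and the taskMemoryManager in scope are assumptions about the call site, which this diff does not include.

import org.apache.spark.unsafe.Platform

// Hypothetical wiring around the UnsafeArray added in this patch.
val array = new UnsafeArray(taskMemoryManager)
serialized.foreach { bytes: Array[Byte] =>
  // Offsets passed to write() are relative to the source object, so use the
  // byte-array base offset when copying from an on-heap Array[Byte].
  array.write(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length)
}
var i = 0
while (i < array.getLength) {
  val record = array.get(i) // UnsafeArrayData view over the off-heap page
  i += 1
}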