
Commit 6bbef0c

squashed commit: rebase + added spark transform apply

e-strauss committed Aug 11, 2023
1 parent 7115b37

Showing 17 changed files with 524 additions and 196 deletions.
---- DenseBlockFP64DEDUP.java (filename inferred from the diff content) ----
@@ -48,10 +48,12 @@ protected void allocateBlock(int bix, int length) {

 	@Override
 	public void reset(int rlen, int[] odims, double v) {
-		if(rlen > capacity() / _odims[0])
+		if(rlen > _rlen)
 			_data = new double[rlen][];
-		else {
-			if(v == 0.0) {
+		else{
+			if(_data == null)
+				_data = new double[rlen][];
+			if(v == 0.0){
 				for(int i = 0; i < rlen; i++)
 					_data[i] = null;
 			}
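Note: in a dedup block, _data is a row-pointer array (double[rlen][]), so a zero-reset only drops row references instead of overwriting values. A minimal standalone sketch of that representation (hypothetical demo code, not part of the commit):

	public class RowPointerResetDemo {
		public static void main(String[] args) {
			double[][] data = new double[3][];
			data[0] = new double[]{1, 2};  // materialized row
			data[1] = data[0];             // duplicate row shares the same array
			// reset to zero: dropping the reference marks the row as all-zero
			for (int i = 0; i < data.length; i++)
				data[i] = null;
			System.out.println(data[0]);   // null, i.e., logically a zero row
		}
	}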
@@ -177,6 +179,12 @@ public int pos(int[] ix){
 	public int blockSize(int bix) {
 		return 1;
 	}
 
+	@Override
+	public boolean isContiguous() {
+		return false;
+	}
 	@Override
 	public boolean isContiguous(int rl, int ru){
 		return rl == ru;
 	}
@@ -251,6 +259,25 @@ public DenseBlock set(DenseBlock db) {
 		throw new NotImplementedException();
 	}
 
+	@Override
+	public DenseBlock set(int rl, int ru, int ol, int ou, DenseBlock db) {
+		if( !(db instanceof DenseBlockFP64DEDUP))
+			throw new NotImplementedException();
+		HashMap<double[], double[]> cache = new HashMap<>();
+		int len = ou - ol;
+		for(int i=rl, ix1 = 0; i<ru; i++, ix1++){
+			double[] row = db.values(ix1);
+			double[] newRow = cache.get(row);
+			if (newRow == null) {
+				newRow = new double[len];
+				System.arraycopy(row, 0, newRow, 0, len);
+				cache.put(row, newRow);
+			}
+			set(i, newRow);
+		}
+		return this;
+	}
+
 	@Override
 	public DenseBlock set(int[] ix, double v) {
 		return set(ix[0], pos(ix), v);
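The HashMap in the new set(int, int, int, int, DenseBlock) keys on double[] references: Java arrays use identity-based hashCode/equals, and a dedup block stores duplicate rows as the same array instance, so each distinct row is copied exactly once and its duplicates share the copy. A standalone sketch of the effect (hypothetical demo code, not part of the commit):

	import java.util.HashMap;

	public class DedupCacheDemo {
		public static void main(String[] args) {
			double[] shared = {1.0, 2.0, 3.0};                  // one physical row
			double[][] rows = {shared, shared, {4.0, 5.0, 6.0}};
			HashMap<double[], double[]> cache = new HashMap<>();
			double[][] out = new double[rows.length][];
			for (int i = 0; i < rows.length; i++) {
				double[] copy = cache.get(rows[i]);             // identity lookup, not content
				if (copy == null) {
					copy = rows[i].clone();                     // copy each distinct row once
					cache.put(rows[i], copy);
				}
				out[i] = copy;                                  // duplicates share the new copy
			}
			System.out.println(out[0] == out[1]);               // true: dedup preserved
			System.out.println(out[0] == out[2]);               // false: distinct rows
		}
	}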
---- DenseBlockFactory.java (filename inferred from the diff content) ----
@@ -52,8 +52,13 @@ public static DenseBlock createDenseBlock(ValueType vt, int[] dims) {
 	}
 
 	public static DenseBlock createDenseBlock(ValueType vt, int[] dims, boolean dedup) {
-		DenseBlock.Type type = (UtilFunctions.prod(dims) < Integer.MAX_VALUE) ?
-			DenseBlock.Type.DRB : DenseBlock.Type.LDRB;
+		DenseBlock.Type type;
+		if(dedup)
+			type = (dims[0] < Integer.MAX_VALUE) ?
+				DenseBlock.Type.DRB : DenseBlock.Type.LDRB;
+		else
+			type = (UtilFunctions.prod(dims) < Integer.MAX_VALUE) ?
+				DenseBlock.Type.DRB : DenseBlock.Type.LDRB;
 		return createDenseBlock(vt, type, dims, dedup);
 	}
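Why dedup changes the size check: a dedup block keeps one double[] per row, so only the row count must stay addressable, whereas a flat dense block needs rows*cols to fit into a single array. A standalone sketch of the decision (hypothetical helper, not part of the commit):

	public class BlockTypeChoiceDemo {
		// Mirrors the selection above: non-large row block (DRB) if a single
		// (row-pointer or flat) block suffices, large variant (LDRB) otherwise.
		static String chooseType(int[] dims, boolean dedup) {
			long cells = 1;
			for (int d : dims)
				cells *= d;
			boolean single = dedup ? dims[0] < Integer.MAX_VALUE : cells < Integer.MAX_VALUE;
			return single ? "DRB" : "LDRB";
		}
		public static void main(String[] args) {
			// 3M x 1000 has 3G cells: too many for one flat array, but only 3M rows,
			// so a dedup (row-pointer) block can still use the non-large variant.
			int[] dims = {3_000_000, 1000};
			System.out.println(chooseType(dims, false)); // LDRB
			System.out.println(chooseType(dims, true));  // DRB
		}
	}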

---- ParameterizedBuiltinSPInstruction.java (filename inferred from the diff content) ----
@@ -58,8 +58,10 @@
 import org.apache.sysds.runtime.instructions.spark.functions.PerformGroupByAggInCombiner;
 import org.apache.sysds.runtime.instructions.spark.functions.PerformGroupByAggInReducer;
 import org.apache.sysds.runtime.instructions.spark.functions.ReplicateVectorFunction;
+import org.apache.sysds.runtime.instructions.spark.utils.FrameRDDAggregateUtils;
 import org.apache.sysds.runtime.instructions.spark.utils.FrameRDDConverterUtils;
 import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
+import org.apache.sysds.runtime.instructions.spark.utils.RDDConverterUtils;
 import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
 import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -77,6 +79,7 @@
 import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
 import org.apache.sysds.runtime.transform.decode.Decoder;
 import org.apache.sysds.runtime.transform.decode.DecoderFactory;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite;
 import org.apache.sysds.runtime.transform.encode.EncoderFactory;
 import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
 import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
@@ -500,6 +503,8 @@ else if(opcode.equalsIgnoreCase("transformapply")) {
 			JavaPairRDD<Long, FrameBlock> in = (JavaPairRDD<Long, FrameBlock>) sec.getRDDHandleForFrameObject(fo,
 				FileFormat.BINARY);
 			FrameBlock meta = sec.getFrameInput(params.get("meta"));
+			MatrixBlock embeddings = params.get("embedding") != null ? ec.getMatrixInput(params.get("embedding")) : null;
+
 			DataCharacteristics mcIn = sec.getDataCharacteristics(params.get("target"));
 			DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName());
 			String[] colnames = !TfMetaUtils.isIDSpec(params.get("spec")) ? in.lookup(1L).get(0)
@@ -514,20 +519,36 @@ else if(opcode.equalsIgnoreCase("transformapply")) {

 			// create encoder broadcast (avoiding replication per task)
 			MultiColumnEncoder encoder = EncoderFactory
-				.createEncoder(params.get("spec"), colnames, fo.getSchema(), (int) fo.getNumColumns(), meta);
-			mcOut.setDimension(mcIn.getRows() - ((omap != null) ? omap.getNumRmRows() : 0), encoder.getNumOutCols());
+				.createEncoder(params.get("spec"), colnames, fo.getSchema(), (int) fo.getNumColumns(), meta, embeddings);
+			encoder.updateAllDCEncoders();
+			mcOut.setDimension(mcIn.getRows() - ((omap != null) ? omap.getNumRmRows() : 0),
+				(int)encoder.getNumOutCols());
 			Broadcast<MultiColumnEncoder> bmeta = sec.getSparkContext().broadcast(encoder);
 			Broadcast<TfOffsetMap> bomap = (omap != null) ? sec.getSparkContext().broadcast(omap) : null;
 
 			// execute transform apply
-			JavaPairRDD<Long, FrameBlock> tmp = in.mapToPair(new RDDTransformApplyFunction(bmeta, bomap));
-			JavaPairRDD<MatrixIndexes, MatrixBlock> out = FrameRDDConverterUtils
-				.binaryBlockToMatrixBlock(tmp, mcOut, mcOut);
+			JavaPairRDD<MatrixIndexes, MatrixBlock> out;
+			Tuple2<Boolean, Integer> aligned = FrameRDDAggregateUtils.checkRowAlignment(in, -1);
+			if(aligned._1 && mcOut.getCols() <= aligned._2) {
+				//blocks are row-aligned & number of columns is at most the block size (necessary for matrix-matrix reblock)
+				JavaPairRDD<Long, MatrixBlock> tmp = in.mapToPair(new RDDTransformApplyFunction2(bmeta, bomap));
+				mcIn.setBlocksize(aligned._2);
+				mcIn.setDimension(mcIn.getRows(), mcOut.getCols());
+				JavaPairRDD<MatrixIndexes, MatrixBlock> tmp2 = tmp.mapToPair((PairFunction<Tuple2<Long, MatrixBlock>, MatrixIndexes, MatrixBlock>) in12 ->
+					new Tuple2<>(new MatrixIndexes(UtilFunctions.computeBlockIndex(in12._1, aligned._2), 1), in12._2));
+				out = RDDConverterUtils.binaryBlockToBinaryBlock(tmp2, mcIn, mcOut);
+				//out = RDDConverterUtils.matrixBlockToAlignedMatrixBlock(tmp, mcOut, mcOut);
+			} else {
+				JavaPairRDD<Long, FrameBlock> tmp = in.mapToPair(new RDDTransformApplyFunction(bmeta, bomap));
+				out = FrameRDDConverterUtils.binaryBlockToMatrixBlock(tmp, mcOut, mcOut);
+			}
 
 			// set output and maintain lineage/output characteristics
 			sec.setRDDHandleForVariable(output.getName(), out);
 			sec.addLineageRDD(output.getName(), params.get("target"));
 			ec.releaseFrameInput(params.get("meta"));
+			if(params.get("embedding") != null)
+				ec.releaseMatrixInput(params.get("embedding"));
 		}
 		else if(opcode.equalsIgnoreCase("transformdecode")) {
 			// get input RDD and meta data
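The new fast path above triggers when checkRowAlignment reports that all frame blocks (except possibly the last) have the same row count and the encoded output fits into one column block (mcOut.getCols() <= aligned._2); each encoded FrameBlock then becomes exactly one MatrixBlock, and its row-block index follows directly from the 1-based frame key. A standalone sketch of that mapping (simplified stand-in for UtilFunctions.computeBlockIndex; demo code, not part of the commit):

	public class AlignedIndexDemo {
		// 1-based cell offset -> 1-based block index
		static long computeBlockIndex(long cellIndex, int blen) {
			return (cellIndex - 1) / blen + 1;
		}
		public static void main(String[] args) {
			int blen = 1000;                       // block size reported by checkRowAlignment
			long[] frameKeys = {1L, 1001L, 2001L}; // aligned frame blocks of 1000 rows each
			for (long key : frameKeys)
				System.out.println(key + " -> row block " + computeBlockIndex(key, blen));
			// 1 -> 1, 1001 -> 2, 2001 -> 3: each frame block maps to exactly one
			// matrix row block, so no shuffle/merge of partial blocks is required.
		}
	}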
@@ -908,7 +929,6 @@ public Tuple2<Long, FrameBlock> call(Tuple2<Long, FrameBlock> in) throws Exception {
 			// execute block transform apply
 			MultiColumnEncoder encoder = _bencoder.getValue();
 			MatrixBlock tmp = encoder.apply(blk);
-
 			// remap keys
 			if(_omap != null) {
 				key = _omap.getValue().getOffset(key);
@@ -919,6 +939,8 @@ public Tuple2<Long, FrameBlock> call(Tuple2<Long, FrameBlock> in) throws Exception {
 		}
 	}
 
+
+
 	public static class RDDTransformApplyOffsetFunction implements PairFunction<Tuple2<Long, FrameBlock>, Long, Long> {
 		private static final long serialVersionUID = 3450977356721057440L;
 
@@ -955,6 +977,35 @@ public Tuple2<Long, Long> call(Tuple2<Long, FrameBlock> in) throws Exception {
 		}
 	}
 
+	public static class RDDTransformApplyFunction2 implements PairFunction<Tuple2<Long, FrameBlock>, Long, MatrixBlock> {
+		private static final long serialVersionUID = 5759813006068230916L;
+
+		private Broadcast<MultiColumnEncoder> _bencoder = null;
+		private Broadcast<TfOffsetMap> _omap = null;
+
+		public RDDTransformApplyFunction2(Broadcast<MultiColumnEncoder> bencoder, Broadcast<TfOffsetMap> omap) {
+			_bencoder = bencoder;
+			_omap = omap;
+		}
+
+		@Override
+		public Tuple2<Long, MatrixBlock> call(Tuple2<Long, FrameBlock> in) throws Exception {
+			long key = in._1();
+			FrameBlock blk = in._2();
+
+			// execute block transform apply
+			MultiColumnEncoder encoder = _bencoder.getValue();
+			MatrixBlock tmp = encoder.apply(blk);
+			// remap keys
+			if(_omap != null) {
+				key = _omap.getValue().getOffset(key);
+			}
+
+			// return the matrix block directly (no frame conversion needed)
+			return new Tuple2<>(key, tmp);
+		}
+	}
+
 	public static class RDDTransformDecodeFunction
 		implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, Long, FrameBlock> {
 		private static final long serialVersionUID = -4797324742568170756L;
---- FrameRDDAggregateUtils.java ----
@@ -20,14 +20,77 @@
 package org.apache.sysds.runtime.instructions.spark.utils;
 
 import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.PairFunction;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.frame.data.FrameBlock;
+import scala.Function3;
 import scala.Tuple2;
+import scala.Tuple3;
+import scala.Tuple4;
+import scala.Tuple5;
 
 
 public class FrameRDDAggregateUtils
 {
+	public static Tuple2<Boolean, Integer> checkRowAlignment(JavaPairRDD<Long,FrameBlock> in, int blen){
+		JavaRDD<Tuple5<Boolean, Long, Integer, Integer, Boolean>> row_rdd = in.map((Function<Tuple2<Long, FrameBlock>, Tuple5<Boolean, Long, Integer, Integer, Boolean>>) in1 -> {
+			long key = in1._1();
+			FrameBlock blk = in1._2();
+			return new Tuple5<>(true, key, blen == -1 ? blk.getNumRows() : blen, blk.getNumRows(), true);
+		});
+		Tuple5<Boolean, Long, Integer, Integer, Boolean> result = row_rdd.fold(null, (Function2<Tuple5<Boolean, Long, Integer, Integer, Boolean>, Tuple5<Boolean, Long, Integer, Integer, Boolean>, Tuple5<Boolean, Long, Integer, Integer, Boolean>>) (in1, in2) -> {
+			//trivial cases: empty partial result or already-detected misalignment
+			if (in1 == null)
+				return in2;
+			if (in2 == null)
+				return in1;
+			if (!in1._1() || !in2._1())
+				return new Tuple5<>(false, null, null, null, null);
+
+			//default evaluation
+			int in1_max = in1._3();
+			int in1_min = in1._4();
+			long in1_min_index = in1._2(); //index of the block with the fewest rows (for an aligned result, the block with the largest index, i.e., the last block)
+			int in2_max = in2._3();
+			int in2_min = in2._4();
+			long in2_min_index = in2._2();
+
+			boolean in1_isSingleBlock = in1._5();
+			boolean in2_isSingleBlock = in2._5();
+			boolean min_index_comp = in1_min_index > in2_min_index;
+
+			if (in1_max == in2_max) {
+				if (in1_min == in1_max) {
+					if (in2_min == in2_max)
+						return new Tuple5<>(true, min_index_comp ? in1_min_index : in2_min_index, in1_max, in1_max, false);
+					else if (!min_index_comp)
+						return new Tuple5<>(true, in2_min_index, in1_max, in2_min, false);
+					//else: in1_min_index > in2_min_index --> in2 is not aligned
+				} else {
+					if (in2_min == in2_max)
+						if (min_index_comp)
+							return new Tuple5<>(true, in1_min_index, in1_max, in1_min, false);
+					//else: in1_min_index < in2_min_index --> in1 is not aligned
+					//else: both contain blocks with fewer rows than max
+				}
+			} else {
+				if (in1_max > in2_max && in1_min == in1_max && in2_isSingleBlock && in1_min_index < in2_min_index)
+					return new Tuple5<>(true, in2_min_index, in1_max, in2_min, false);
+				/* else:
+				   in1_min != in1_max -> in1 contains blocks with fewer rows than max
+				   !in2_isSingleBlock -> in2 contains at least 2 blocks with fewer rows than in1's max
+				   in1_min_index > in2_min_index -> in2's min block is not the last block
+				 */
+				if (in1_max < in2_max && in2_min == in2_max && in1_isSingleBlock && in2_min_index < in1_min_index)
+					return new Tuple5<>(true, in1_min_index, in2_max, in1_min, false);
+			}
+			return new Tuple5<>(false, null, null, null, null);
+		});
+		return new Tuple2<>(result._1(), result._3());
+	}
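The fold above combines per-block tuples of the form (aligned, index of the min-rows block, max rows, min rows, single block?); alignment holds iff every block has the same row count except possibly the block with the largest index. A sequential restatement of that property (hypothetical demo code, not the distributed implementation):

	public class AlignmentCheckDemo {
		// true iff all blocks have the same row count except possibly the
		// block with the largest key (the last one), which may be shorter
		static boolean aligned(long[] keys, int[] rows) {
			int max = 0;
			for (int r : rows) max = Math.max(max, r);
			long lastKey = Long.MIN_VALUE; int lastIdx = -1;
			for (int i = 0; i < keys.length; i++)
				if (keys[i] > lastKey) { lastKey = keys[i]; lastIdx = i; }
			for (int i = 0; i < rows.length; i++)
				if (rows[i] != max && i != lastIdx)
					return false;
			return true;
		}
		public static void main(String[] args) {
			System.out.println(aligned(new long[]{1, 1001, 2001}, new int[]{1000, 1000, 250})); // true
			System.out.println(aligned(new long[]{1, 1001, 1251}, new int[]{1000, 250, 1000})); // false
		}
	}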

 	public static JavaPairRDD<Long, FrameBlock> mergeByKey( JavaPairRDD<Long, FrameBlock> in )
 	{
---- RDDConverterUtils.java (filename inferred from the diff content) ----
@@ -54,6 +54,8 @@
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.data.DenseBlockFP64DEDUP;
 import org.apache.sysds.runtime.data.SparseBlock;
 import org.apache.sysds.runtime.instructions.spark.data.ReblockBuffer;
 import org.apache.sysds.runtime.instructions.spark.data.SerLongWritable;
@@ -377,6 +379,18 @@ public static void libsvmToBinaryBlock(JavaSparkContext sc, String pathIn,
 		}
 	}
 
+	//can be removed if unnecessary; essentially the frame-matrix reblock, but for matrix blocks
+	public static JavaPairRDD<MatrixIndexes, MatrixBlock> matrixBlockToAlignedMatrixBlock(JavaPairRDD<Long, MatrixBlock> input,
+		DataCharacteristics mcIn, DataCharacteristics mcOut)
+	{
+		//align matrix blocks
+		JavaPairRDD<MatrixIndexes, MatrixBlock> out = input
+			.flatMapToPair(new RDDConverterUtils.MatrixBlockToAlignedMatrixBlockFunction(mcIn, mcOut));
+
+		//aggregate partial matrix blocks
+		return RDDAggregateUtils.mergeByKey(out, false);
+	}
+
 	public static JavaPairRDD<LongWritable, Text> stringToSerializableText(JavaPairRDD<Long,String> in)
 	{
 		return in.mapToPair(new TextToSerTextFunction());
@@ -1433,5 +1447,51 @@ public static JavaPairRDD<MatrixIndexes, MatrixBlock> libsvmToBinaryBlock(JavaSparkContext
 	}
 	///////////////////////////////
 	// END LIBSVM FUNCTIONS
 
+	private static class MatrixBlockToAlignedMatrixBlockFunction implements PairFlatMapFunction<Tuple2<Long,MatrixBlock>,MatrixIndexes, MatrixBlock> {
+		private static final long serialVersionUID = -2654986510471835933L;
+
+		private DataCharacteristics _mcIn;
+		private DataCharacteristics _mcOut;
+		public MatrixBlockToAlignedMatrixBlockFunction(DataCharacteristics mcIn, DataCharacteristics mcOut) {
+			_mcIn = mcIn; //input (frame-side) characteristics
+			_mcOut = mcOut; //output matrix characteristics
+		}
+		@Override
+		public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<Long, MatrixBlock> arg0)
+			throws Exception
+		{
+			long rowIndex = arg0._1();
+			MatrixBlock blk = arg0._2();
+			boolean dedup = blk.getDenseBlock() instanceof DenseBlockFP64DEDUP;
+			ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<>();
+			long rlen = _mcIn.getRows();
+			long clen = _mcIn.getCols();
+			int blen = _mcOut.getBlocksize();
+
+			//slice aligned matrix blocks out of the given row-offset-keyed block
+			long rstartix = UtilFunctions.computeBlockIndex(rowIndex, blen);
+			long rendix = UtilFunctions.computeBlockIndex(rowIndex+blk.getNumRows()-1, blen);
+			long cendix = UtilFunctions.computeBlockIndex(blk.getNumColumns(), blen);
+			for( long rix=rstartix; rix<=rendix; rix++ ) { //for all row blocks
+				long rpos = UtilFunctions.computeCellIndex(rix, blen, 0);
+				int lrlen = UtilFunctions.computeBlockSize(rlen, rix, blen);
+				int fix = (int)((rpos-rowIndex>=0) ? rpos-rowIndex : 0);
+				int fix2 = (int)Math.min(rpos+lrlen-rowIndex-1,blk.getNumRows()-1);
+				int mix = UtilFunctions.computeCellInBlock(rowIndex+fix, blen);
+				int mix2 = mix + (fix2-fix);
+				for( long cix=1; cix<=cendix; cix++ ) { //for all column blocks
+					long cpos = UtilFunctions.computeCellIndex(cix, blen, 0);
+					int lclen = UtilFunctions.computeBlockSize(clen, cix, blen);
+					MatrixBlock tmp = blk.slice(fix, fix2,
+						(int)cpos-1, (int)cpos+lclen-2, new MatrixBlock());
+					MatrixBlock newBlock = new MatrixBlock(lrlen, lclen, false);
+					ret.add(new Tuple2<>(new MatrixIndexes(rix, cix), newBlock.leftIndexingOperations(tmp, mix, mix2, 0, lclen-1,
+						new MatrixBlock(), MatrixObject.UpdateType.INPLACE_PINNED)));
+				}
+			}
+			return ret.iterator();
+		}
+	}
 }
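For reference, the slice bounds computed above split a row segment across row-block boundaries. A standalone arithmetic sketch (simplified stand-ins for the UtilFunctions helpers; demo code, not part of the commit):

	public class SliceBoundsDemo {
		static long blockIndex(long cell, int blen) { return (cell - 1) / blen + 1; }
		static long cellIndex(long bix, int blen) { return (bix - 1) * blen + 1; }

		public static void main(String[] args) {
			int blen = 1000;
			long rowIndex = 1501; int nrows = 1000;             // source rows 1501..2500
			long first = blockIndex(rowIndex, blen);            // 2
			long last = blockIndex(rowIndex + nrows - 1, blen); // 3
			for (long bix = first; bix <= last; bix++) {
				long rpos = cellIndex(bix, blen);               // first row of target block
				int fix = (int) Math.max(rpos - rowIndex, 0);                     // slice start in source
				int fix2 = (int) Math.min(rpos + blen - rowIndex - 1, nrows - 1); // slice end in source
				int mix = (int) ((rowIndex + fix - 1) % blen);                    // dest offset in block
				System.out.println("block " + bix + ": src[" + fix + ".." + fix2
					+ "] -> dst[" + mix + ".." + (mix + fix2 - fix) + "]");
			}
			// block 2: src[0..499] -> dst[500..999]; block 3: src[500..999] -> dst[0..499]
		}
	}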

---- (binary-block to text-cell converter; file name not shown on this page) ----
@@ -25,6 +25,7 @@

 import org.apache.hadoop.io.NullWritable;
 import org.apache.hadoop.io.Text;
+import org.apache.sysds.runtime.data.DenseBlockFP64DEDUP;
 import org.apache.sysds.runtime.util.UtilFunctions;
 
 
@@ -74,7 +75,17 @@ public void convert(MatrixIndexes k1, MatrixBlock v1) {
 	{
 		if(v1.getDenseBlock()==null)
 			return;
-		denseArray=v1.getDenseBlockValues();
+		if(v1.getDenseBlock() instanceof DenseBlockFP64DEDUP){
+			DenseBlockFP64DEDUP db = (DenseBlockFP64DEDUP) v1.getDenseBlock();
+			denseArray = new double[v1.rlen*v1.clen];
+			for (int i = 0; i < v1.rlen; i++) {
+				double[] row = db.values(i);
+				for (int j = 0; j < v1.clen; j++) {
+					denseArray[i*v1.clen + j] = row[j];
+				}
+			}
+		} else
+			denseArray=v1.getDenseBlockValues();
 		nextInDenseArray=0;
 		denseArraySize=v1.getNumRows()*v1.getNumColumns();
 	}
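Note that this conversion materializes the dedup block into one flat row-major array, temporarily giving up the deduplication savings: for r rows, c columns, and d distinct rows, memory grows from roughly d*c to r*c doubles (ignoring the per-row pointer array). A back-of-envelope sketch with hypothetical numbers:

	public class DedupFlattenCost {
		public static void main(String[] args) {
			long r = 1_000_000, c = 300, d = 1_000;  // many repeated embedding rows
			System.out.println("dedup: " + (d * c * 8) / (1 << 20) + " MB"); // ~2 MB
			System.out.println("flat:  " + (r * c * 8) / (1 << 20) + " MB"); // ~2288 MB
		}
	}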
(Diffs for the remaining changed files did not load on this page.)
