Spark Read Compressed Matrix
Baunsgaard committed Aug 18, 2023
1 parent b3e92f4 commit 11ba58a
Showing 8 changed files with 234 additions and 101 deletions.
@@ -153,7 +153,6 @@ protected AColGroup sliceSingleColumn(int idx) {

@Override
protected AColGroup sliceMultiColumns(int idStart, int idEnd, IColIndex outputCols) {
LOG.error(outputCols);
final IDictionary retDict = _dict.sliceOutColumnRange(idStart, idEnd, _colIndexes.size());
if(retDict == null)
return new ColGroupEmpty(outputCols);
@@ -425,13 +425,19 @@ private static int longSize(int size) {
return Math.max(size >> 6, 0) + 1;
}

public int getMaxPossible(){
public int getMaxPossible() {
return 2;
}

@Override
public String toString() {
return super.toString() + _size + "[" + _data.length + "]";
StringBuilder sb = new StringBuilder();
sb.append(super.toString());
sb.append("size: " + _size);
sb.append(" longLength:[");
sb.append(_data.length);
sb.append("]");
return sb.toString();

}
}
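As a worked example of the new toString output (assuming _data is allocated with longSize(_size)): for _size = 1000, longSize gives (1000 >> 6) + 1 = 16 longs, so the string would end in "size: 1000 longLength:[16]".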
@@ -12,7 +12,11 @@

public class DictWritable implements Writable {

List<IDictionary> dicts;
public List<IDictionary> dicts;

public DictWritable(){

}

protected DictWritable(List<IDictionary> dicts) {
this.dicts = dicts;
@@ -45,7 +49,11 @@ public String toString() {
}

public static class K implements Writable {
int id;
public int id;

public K(){

}

public K(int id) {
this.id = id;
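A plausible reason for the new public no-arg constructors and public fields is Hadoop's reflective instantiation of Writable key/value classes, which needs a reachable default constructor before readFields(...) can fill the object. A minimal sketch under that assumption (the wrapper class WritableReflectionSketch is hypothetical):

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.sysds.runtime.compress.io.DictWritable;

public class WritableReflectionSketch {
	public static void main(String[] args) {
		// Hadoop-style reflective construction of the Writable key/value pair;
		// without a no-arg constructor this step would fail.
		Configuration conf = new Configuration();
		DictWritable.K key = ReflectionUtils.newInstance(DictWritable.K.class, conf);
		DictWritable value = ReflectionUtils.newInstance(DictWritable.class, conf);
		System.out.println(key.id);      // 0 until readFields(...) populates it
		System.out.println(value.dicts); // null until readFields(...) populates it
	}
}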
@@ -82,9 +82,13 @@ private static MatrixBlock readCompressedMatrix(String fname, JobConf job, FileS
data.putAll(readColumnGroups(subPath, job));
}

final Map<Integer, List<IDictionary>> dicts = new HashMap<>();
for(Path subPath : IOUtilFunctions.getSequenceFilePaths(fs, new Path(fname + ".dict"))) {
dicts.putAll(readDictionaries(subPath, job));
final Path dictPath = new Path(fname + ".dict");
Map<Integer, List<IDictionary>> dicts = null;
if(fs.exists(dictPath)) {
dicts = new HashMap<>();
for(Path subPath : IOUtilFunctions.getSequenceFilePaths(fs, dictPath)) {
dicts.putAll(readDictionaries(subPath, job));
}
}

if(data.containsValue(null))
@@ -130,7 +134,7 @@ private static Map<Integer, List<IDictionary>> readDictionaries(Path path, JobCo
// Use write and read interface to read and write this object.
DictWritable.K key = new DictWritable.K(0);
DictWritable value = new DictWritable(null);
while(reader.next(key, value))
while(reader.next(key, value))
data.put(key.id, value.dicts);
}
finally {
@@ -0,0 +1,86 @@
package org.apache.sysds.runtime.compress.io;

import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.mapred.SequenceFileInputFormat;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ADictBasedColGroup;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;

import scala.Tuple2;

public interface ReaderSparkCompressed {
public static final Log LOG = LogFactory.getLog(ReaderSparkCompressed.class.getName());

@SuppressWarnings("unchecked")
public static JavaPairRDD<MatrixIndexes, MatrixBlock> getRDD(JavaSparkContext sc, String fileName) {
JavaPairRDD<MatrixIndexes, CompressedWriteBlock> cmbrdd = sc.hadoopFile(fileName, SequenceFileInputFormat.class,
MatrixIndexes.class, CompressedWriteBlock.class);
JavaPairRDD<DictWritable.K, DictWritable> dictsRdd = sc.hadoopFile(fileName + ".dict",
SequenceFileInputFormat.class, DictWritable.K.class, DictWritable.class);

return combineRdds(cmbrdd, dictsRdd);
}

private static JavaPairRDD<MatrixIndexes, MatrixBlock> combineRdds(
JavaPairRDD<MatrixIndexes, CompressedWriteBlock> cmbRdd, JavaPairRDD<DictWritable.K, DictWritable> dictsRdd) {
// combine the elements
JavaPairRDD<MatrixIndexes, MatrixBlock> mbrdd = cmbRdd.mapValues(new CompressUnwrap());
JavaPairRDD<Integer, List<IDictionary>> dictsUnpacked = dictsRdd
.mapToPair((t) -> new Tuple2<>(Integer.valueOf(t._1.id + 1), t._2.dicts));
JavaPairRDD<Integer, Tuple2<MatrixIndexes, MatrixBlock>> mbrddC = mbrdd
.mapToPair((t) -> new Tuple2<>(Integer.valueOf((int) t._1.getColumnIndex()), t));

JavaPairRDD<Integer, Tuple2<Tuple2<MatrixIndexes, MatrixBlock>, List<IDictionary>>> j = mbrddC
.join(dictsUnpacked);

JavaPairRDD<MatrixIndexes, MatrixBlock> ret = j.mapToPair(ReaderSparkCompressed::combineTuples);
return ret;
}

private static Tuple2<MatrixIndexes, MatrixBlock> combineTuples(
Tuple2<Integer, Tuple2<Tuple2<MatrixIndexes, MatrixBlock>, List<IDictionary>>> e) {
MatrixIndexes kOut = e._2._1._1;
MatrixBlock mbIn = e._2._1._2;
List<IDictionary> dictsIn = e._2._2;
MatrixBlock ob = combineMatrixBlockAndDict(mbIn, dictsIn);
return new Tuple2<>(kOut, ob);
}

private static MatrixBlock combineMatrixBlockAndDict(MatrixBlock mb, List<IDictionary> dicts) {
if(mb instanceof CompressedMatrixBlock) {
CompressedMatrixBlock cmb = (CompressedMatrixBlock) mb;
List<AColGroup> gs = cmb.getColGroups();

if(dicts.size() == gs.size()) {
for(int i = 0; i < dicts.size(); i++)
gs.set(i, ((ADictBasedColGroup) gs.get(i)).copyAndSet(dicts.get(i)));
}
else {
LOG.error(dicts.size());
LOG.error(gs.size());
int gis = 0;
for(int i = 0; i < gs.size(); i++) {
AColGroup g = gs.get(i);
if(g instanceof ADictBasedColGroup) {
ADictBasedColGroup dg = (ADictBasedColGroup) g;
gs.set(i, dg.copyAndSet(dicts.get(gis)));
gis++;
}
}
}

return new CompressedMatrixBlock(cmb.getNumRows(), cmb.getNumColumns(), cmb.getNonZeros(),
cmb.isOverlapping(), gs);
}
else
return mb;
}
}
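A hedged usage sketch of the new Spark reader (the path, app name, and the wrapper class ReadCompressedRDDSketch are hypothetical; it assumes a compressed matrix was previously written to /tmp/X.cla together with its /tmp/X.cla.dict side file):

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.sysds.runtime.compress.io.ReaderSparkCompressed;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;

public class ReadCompressedRDDSketch {
	public static void main(String[] args) {
		JavaSparkContext sc = new JavaSparkContext(
			new SparkConf().setAppName("readCompressed").setMaster("local[*]"));
		// Load the block RDD and the dictionary RDD, join them, and re-attach
		// the dictionaries to the dictionary-based column groups.
		JavaPairRDD<MatrixIndexes, MatrixBlock> blocks =
			ReaderSparkCompressed.getRDD(sc, "/tmp/X.cla");
		// The combine is lazy; counting the blocks forces the join.
		System.out.println("number of blocks: " + blocks.count());
		sc.close();
	}
}

Note the join key in combineRdds above: dictionary keys are shifted by one (t._1.id + 1) so they line up with the 1-based MatrixIndexes.getColumnIndex() of the corresponding blocks.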
124 changes: 39 additions & 85 deletions src/main/java/org/apache/sysds/runtime/compress/lib/CLALibStack.java
@@ -109,73 +109,14 @@ private static long findCLength(Map<MatrixIndexes, MatrixBlock> m, MatrixBlock b

private static MatrixBlock combine(final Map<MatrixIndexes, MatrixBlock> m, Map<Integer, List<IDictionary>> d,
final MatrixIndexes lookup, final int rlen, final int clen, final int blen, final int k) {

// if(rlen < blen) // Shortcut, in case file only contains one block in r length.
// return combineColumnGroups(m, d, lookup, rlen, clen, blen, k);

// final CompressionType[] colTypes = new CompressionType[clen];
// // Look through the first blocks in to the top.
// for(int bc = 0; bc * blen < clen; bc++) {
// lookup.setIndexes(1, bc + 1); // get first blocks
// final MatrixBlock b = m.get(lookup);
// if(!(b instanceof CompressedMatrixBlock)) {
// LOG.warn("Found uncompressed matrix in Map of matrices, this is not"
// + " supported in combine therefore falling back to decompression");
// return combineViaDecompression(m, rlen, clen, blen, k);
// }
// final CompressedMatrixBlock cmb = (CompressedMatrixBlock) b;
// if(cmb.isOverlapping()) {
// LOG.warn("Not supporting overlapping combine yet falling back to decompression");
// return combineViaDecompression(m, rlen, clen, blen, k);
// }
// final List<AColGroup> gs = cmb.getColGroups();
// final int off = bc * blen;
// for(AColGroup g : gs) {
// try {
// final IIterate cols = g.getColIndices().iterator();
// final CompressionType t = g.getCompType();
// while(cols.hasNext())
// colTypes[cols.next() + off] = t;
// }
// catch(Exception e) {
// throw new DMLCompressionException("Failed combining: " + g.toString());
// }
// }
// }

// // Look through the Remaining blocks down in the rows.
// for(int br = 1; br * blen < rlen; br++) {
// for(int bc = 0; bc * blen < clen; bc++) {
// lookup.setIndexes(br + 1, bc + 1); // get first blocks
// final MatrixBlock b = m.get(lookup);
// if(!(b instanceof CompressedMatrixBlock)) {
// LOG.warn("Found uncompressed matrix in Map of matrices, this is not"
// + " supported in combine therefore falling back to decompression");
// return combineViaDecompression(m, rlen, clen, blen, k);
// }
// final CompressedMatrixBlock cmb = (CompressedMatrixBlock) b;
// if(cmb.isOverlapping()) {
// LOG.warn("Not supporting overlapping combine yet falling back to decompression");
// return combineViaDecompression(m, rlen, clen, blen, k);
// }
// final List<AColGroup> gs = cmb.getColGroups();
// final int off = bc * blen;
// for(AColGroup g : gs) {
// final IIterate cols = g.getColIndices().iterator();
// final CompressionType t = g.getCompType();
// while(cols.hasNext()) {
// final int c = cols.next();
// if(colTypes[c + off] != t) {
// LOG.warn("Not supported different types of column groups to combine."
// + "Falling back to decompression of all blocks " + t + " vs " + colTypes[c + off]);
// return combineViaDecompression(m, rlen, clen, blen, k);
// }
// }
// }
// }
// }

return combineColumnGroups(m, d, lookup, rlen, clen, blen, k);
try {
return combineColumnGroups(m, d, lookup, rlen, clen, blen, k);
}
catch(Exception e) {
// throw new RuntimeException("failed normal combine", e);
LOG.error("Failed to combine compressed blocks, fallback to decompression.", e);
return combineViaDecompression(m, rlen, clen, blen, k);
}
}

private static MatrixBlock combineViaDecompression(final Map<MatrixIndexes, MatrixBlock> m, final int rlen,
@@ -201,35 +142,49 @@ private static MatrixBlock combineColumnGroups(final Map<MatrixIndexes, MatrixBl
Map<Integer, List<IDictionary>> d, final MatrixIndexes lookup, final int rlen, final int clen, final int blen,
final int k) {

final AColGroup[][] finalCols = new AColGroup[clen][]; // temp array for combining
final int blocksInColumn = (rlen - 1) / blen + 1;
int nGroups = 0;
for(int bc = 0; bc * blen < clen; bc++) {
// iterate through the first row of blocks to see number of columngroups.
lookup.setIndexes(1, bc + 1);
MatrixBlock mb = m.get(lookup);
if(!(mb instanceof CompressedMatrixBlock)) {
LOG.warn("Combining via decompression. There was an uncompressed MatrixBlock");
return combineViaDecompression(m, rlen, clen, blen, k);
}

final CompressedMatrixBlock cmb = (CompressedMatrixBlock) m.get(lookup);
final List<AColGroup> gs = cmb.getColGroups();
nGroups += gs.size();
}

final int blocksInColumn = rlen / blen + (rlen % blen > 0 ? 1 : 0);
final AColGroup[][] finalCols = new AColGroup[nGroups][blocksInColumn]; // temp array for combining

// LOG.error(m);
// Add all the blocks into linear structure.
for(int br = 0; br * blen < rlen; br++) {
int cgid = 0;
for(int bc = 0; bc * blen < clen; bc++) {
lookup.setIndexes(br + 1, bc + 1);
final CompressedMatrixBlock cmb = (CompressedMatrixBlock) m.get(lookup);
for(AColGroup g : cmb.getColGroups()) {
final List<AColGroup> gs = cmb.getColGroups();

for(int i = 0; i < gs.size(); i++) {
AColGroup g = gs.get(i);
final AColGroup gc = bc > 0 ? g.shiftColIndices(bc * blen) : g;
final int c = gc.getColIndices().get(0);
// LOG.error(c);
if(br == 0)
finalCols[c] = new AColGroup[blocksInColumn];
else if(finalCols[c] == null) {
LOG.warn("Combining via decompression. There was an column"
+ " assigned not assigned in block 1 indicating spark compression");
return combineViaDecompression(m, rlen, clen, blen, k);
}
finalCols[c][br] = gc;
if(br != 0 && (finalCols[c][0] == null ||
!finalCols[c][br].getColIndices().equals(finalCols[c][0].getColIndices()))) {

finalCols[cgid][br] = gc;
if(br != 0 && (finalCols[cgid][0] == null ||
!finalCols[cgid][br].getColIndices().equals(finalCols[cgid][0].getColIndices()))) {
LOG.warn("Combining via decompression. There was an column with different index");
return combineViaDecompression(m, rlen, clen, blen, k);
}
cgid++;

}
}
if(cgid != finalCols.length) {
LOG.warn("Combining via decompression. The number of columngroups in each block is not identical");
return combineViaDecompression(m, rlen, clen, blen, k);
}
}

final ExecutorService pool = CommonThreadPool.get();
Expand All @@ -238,7 +193,6 @@ else if(finalCols[c] == null) {
List<AColGroup> finalGroups = pool.submit(() -> {
return Arrays//
.stream(finalCols)//
.filter(x -> x != null)// filter all columns that are contained in other groups.
.parallel()//
.map(x -> {
return combineN(x);
@@ -62,8 +62,7 @@
import org.apache.sysds.lops.Checkpoint;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.io.CompressUnwrap;
import org.apache.sysds.runtime.compress.io.CompressedWriteBlock;
import org.apache.sysds.runtime.compress.io.ReaderSparkCompressed;
import org.apache.sysds.runtime.controlprogram.Program;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
@@ -461,14 +460,20 @@ else if( mo.isDirty() || mo.isCached(false) || mo.isFederated() || mo instanceof
//recordreader returns; the javadoc explicitly recommend to copy all key/value pairs
// cp is workaround for read bug
rdd = SparkUtils.copyBinaryBlockMatrix((JavaPairRDD<MatrixIndexes, MatrixBlock>)rdd);
else if(fmt == FileFormat.COMPRESSED)
rdd = ((JavaPairRDD<MatrixIndexes, CompressedWriteBlock>) rdd).mapValues(new CompressUnwrap());
else if(fmt.isTextFormat())
else if(fmt == FileFormat.COMPRESSED){
// initial RDDS.
rdd = ReaderSparkCompressed.getRDD(sc, mo.getFileName());
}
else if(fmt.isTextFormat()){
JavaPairRDD<LongWritable, Text> textRDD = sc.hadoopFile(mo.getFileName(), //
inputInfo.inputFormatClass, LongWritable.class, Text.class);
// cp is workaround for read bug
rdd = ((JavaPairRDD<LongWritable, Text>) rdd).mapToPair(new CopyTextInputFunction());
rdd = textRDD.mapToPair(new CopyTextInputFunction());
}
else
throw new DMLRuntimeException("Incorrect input format in getRDDHandleForVariable");


//keep rdd handle for future operations on it
RDDObject rddhandle = new RDDObject(rdd);
rddhandle.setHDFSFile(true);