Skip to content

Commit

Permalink
Fix single threaded
Browse files Browse the repository at this point in the history
  • Loading branch information
Baunsgaard committed Aug 17, 2023
1 parent 793ffb9 commit 014c8de
Show file tree
Hide file tree
Showing 18 changed files with 487 additions and 245 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ private AColGroup directCompressDDCMultiCol(IColIndex colIndexes, CompressedSize
final int fill = d.getUpperBoundValue();
d.fill(fill);

final DblArrayCountHashMap map = new DblArrayCountHashMap(Math.max(cg.getNumVals(), 64), colIndexes.size());
final DblArrayCountHashMap map = new DblArrayCountHashMap(Math.max(cg.getNumVals(), 64));
boolean extra;
if(nRow < CompressionSettings.PAR_DDC_THRESHOLD || k == 1)
extra = readToMapDDC(colIndexes, map, d, 0, nRow, fill);
Expand Down Expand Up @@ -625,7 +625,7 @@ private AColGroup compressMultiColSDCFromSparseTransposedBlock(IColIndex cols, i
IColIndex subCols = ColIndexFactory.create(cols.size());
ReaderColumnSelection reader = ReaderColumnSelection.createReader(sub, subCols, false);
final int mapStartSize = Math.min(nrUniqueEstimate, offsetsInt.length / 2);
DblArrayCountHashMap map = new DblArrayCountHashMap(mapStartSize, subCols.size());
DblArrayCountHashMap map = new DblArrayCountHashMap(mapStartSize);

DblArray cellVals = null;
AMapToData data = MapToFactory.create(offsetsInt.length, 257);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,9 @@ public CompressedMatrixBlock updateAndEncode(MatrixBlock mb) {
List<AColGroup> ret = new ArrayList<>(encodings.length);

for(int i = 0; i < encodings.length; i++) {
encodings[i] = encodings[i].update(mb);
ret.add(encodings[i].encode(mb));
Pair<ICLAScheme, AColGroup> p = encodings[i].updateAndEncode(mb);
encodings[i] = p.getKey();
ret.add(p.getValue());
}

return new CompressedMatrixBlock(mb.getNumRows(), mb.getNumColumns(), mb.getNonZeros(), false, ret);
Expand Down Expand Up @@ -285,7 +286,7 @@ public Object call() throws Exception {
for(int j = i; j < e; j++) {
Pair<ICLAScheme, AColGroup> p = encodings[j].updateAndEncode(mb);
encodings[j] = p.getKey();
ret[j] =p.getValue();
ret[j] = p.getValue();
}
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ protected DDCSchemeMC(ColGroupDDC g) {
final int dictCols = mbDict.getNumColumns();

// Read the mapping data and materialize map.
map = new DblArrayCountHashMap(dictRows * 2, dictCols);
map = new DblArrayCountHashMap(dictRows * 2);
final ReaderColumnSelection r = ReaderColumnSelection.createReader(mbDict, //
ColIndexFactory.create(dictCols), false, 0, dictRows);

Expand All @@ -62,7 +62,7 @@ protected DDCSchemeMC(ColGroupDDC g) {
protected DDCSchemeMC(IColIndex cols) {
super(cols);
final int nCol = cols.size();
this.map = new DblArrayCountHashMap(4, nCol);
this.map = new DblArrayCountHashMap(4);
this.emptyRow = new DblArray(new double[nCol]);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ protected SDCSchemeMC(ASDC g) {
final int dictCols = mbDict.getNumColumns();

// Read the mapping data and materialize map.
map = new DblArrayCountHashMap(dictRows * 2, dictCols);
map = new DblArrayCountHashMap(dictRows * 2);
final ReaderColumnSelection reader = ReaderColumnSelection.createReader(mbDict, //
ColIndexFactory.create(dictCols), false, 0, dictRows);
emptyRow = new DblArray(new double[dictCols]);
Expand Down Expand Up @@ -92,7 +92,7 @@ protected SDCSchemeMC(ASDCZero g) {
final int dictCols = mbDict.getNumColumns();

// Read the mapping data and materialize map.
map = new DblArrayCountHashMap(dictRows * 2, dictCols);
map = new DblArrayCountHashMap(dictRows * 2);
final ReaderColumnSelection r = ReaderColumnSelection.createReader(mbDict, //
ColIndexFactory.create(dictCols), false, 0, dictRows);
DblArray d = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,16 +179,15 @@ else if(map.getOrDefault(0.0, -1) > nCol / 4) {
return new SparseEncoding(d, o, nCol);
}
else {
map.replaceWithUIDs();
// Create output map
final AMapToData d = MapToFactory.create(nCol, nUnique);

// Iteration 2, make final map
for(int i = off, r = 0; i < end; i++, r++) {
if(!Double.isNaN(vals[i]))
d.set(r, map.get(vals[i]));
d.set(r, map.getId(vals[i]));
else
d.set(r, map.get(0.0));
d.set(r, map.getId(0.0));
}

return new DenseEncoding(d);
Expand All @@ -214,7 +213,6 @@ private static IEncode createFromSparseTransposed(MatrixBlock m, int row) {
}

final int nUnique = map.size();
map.replaceWithUIDs();

final int nCol = m.getNumColumns();
if(nUnique == 0) // only if all NaN
Expand All @@ -227,7 +225,7 @@ else if(alen - apos > nCol / 4) { // return a dense encoding
// only iterate through non zero entries.
for(int i = apos; i < alen; i++)
if(!Double.isNaN(avals[i])) // correction one to assign unique IDs taking into account zero
d.set(aix[i], map.get(avals[i]) + correct);
d.set(aix[i], map.getId(avals[i]) + correct);

// the rest is automatically set to zero.
return new DenseEncoding(d);
Expand Down Expand Up @@ -290,14 +288,14 @@ private static IEncode createFromDense(MatrixBlock m, int col) {
}
else {
// Allocate counts, and iterate once to replace counts with u ids
map.replaceWithUIDs();

final AMapToData d = MapToFactory.create(nRow, nUnique);
// Iteration 2, make final map
for(int i = off, r = 0; i < end; i += nCol, r++)
if(!Double.isNaN(vals[i]))
d.set(r, map.get(vals[i]));
d.set(r, map.getId(vals[i]));
else
d.set(r, map.get(0.0));
d.set(r, map.getId(0.0));
return new DenseEncoding(d);
}
}
Expand Down Expand Up @@ -330,7 +328,6 @@ private static IEncode createFromSparse(MatrixBlock m, int col) {
return new EmptyEncoding();

final int nUnique = map.size();
map.replaceWithUIDs();

final AMapToData d = MapToFactory.create(offsets.size(), nUnique);

Expand All @@ -346,7 +343,7 @@ private static IEncode createFromSparse(MatrixBlock m, int col) {
if(index >= 0) {
final double v = sb.values(r)[index];
if(index >= 0 && !Double.isNaN(v))
d.set(off++, map.get(v));
d.set(off++, map.getId(v));
}
}

Expand All @@ -358,7 +355,7 @@ private static IEncode createFromSparse(MatrixBlock m, int col) {
private static IEncode createWithReader(MatrixBlock m, IColIndex rowCols, boolean transposed) {
final ReaderColumnSelection reader1 = ReaderColumnSelection.createReader(m, rowCols, transposed);
final int nRows = transposed ? m.getNumColumns() : m.getNumRows();
final DblArrayCountHashMap map = new DblArrayCountHashMap(16, rowCols.size());
final DblArrayCountHashMap map = new DblArrayCountHashMap();
final IntArrayList offsets = new IntArrayList();
DblArray cellVals = reader1.nextRow();

Expand All @@ -374,7 +371,6 @@ private static IEncode createWithReader(MatrixBlock m, IColIndex rowCols, boolea
else if(map.size() == 1 && offsets.size() == nRows)
return new ConstEncoding(nRows);

map.replaceWithUIDs();
if(offsets.size() < nRows / 4)
// Output encoded sparse since there is very empty.
return createWithReaderSparse(m, map, rowCols, offsets, nRows, transposed);
Expand All @@ -393,10 +389,10 @@ private static IEncode createWithReaderDense(MatrixBlock m, DblArrayCountHashMap
DblArray cellVals;
if(zero)
while((cellVals = reader2.nextRow()) != null)
d.set(reader2.getCurrentRowIndex(), map.get(cellVals) + 1);
d.set(reader2.getCurrentRowIndex(), map.getId(cellVals) + 1);
else
while((cellVals = reader2.nextRow()) != null)
d.set(reader2.getCurrentRowIndex(), map.get(cellVals));
d.set(reader2.getCurrentRowIndex(), map.getId(cellVals));

return new DenseEncoding(d);
}
Expand All @@ -411,7 +407,7 @@ private static IEncode createWithReaderSparse(MatrixBlock m, DblArrayCountHashMa
int i = 0;
// Iterator 2 of non zero tuples.
while(cellVals != null) {
d.set(i++, map.get(cellVals));
d.set(i++, map.getId(cellVals));
cellVals = reader2.nextRow();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,13 +209,8 @@ private void writeMultiBlockCompressedSingleThread(MatrixBlock mb, final int rle
final SeparatedGroups s = CLALibSeparator.split(mc.getColGroups());
final CompressedMatrixBlock rmc = new CompressedMatrixBlock(mc.getNumRows(), mc.getNumColumns(),
mc.getNonZeros(), false, s.indexStructures);
// slice out row blocks in this.
List<MatrixBlock> blocks = CLALibSlice.sliceBlocks(rmc, blen, 1); // Slice compressed blocks
for(int b = 0; b < blocks.size(); b++) {
MatrixIndexes index = new MatrixIndexes(b, bc);
CompressedWriteBlock blk = new CompressedWriteBlock(blocks.get(b - 1));
w.append(index, blk);
}
final int nBlocks = rlen / blen + (rlen % blen > 0 ? 1 : 0);
write(w, rmc, bc + 1, 1, nBlocks + 1, blen);

new DictWriteTask(fname, s.dicts, bc).call();

Expand Down Expand Up @@ -336,6 +331,17 @@ private static void cleanup(Path path) throws IOException {
IOUtilFunctions.deleteCrcFilesFromLocalFileSystem(fs, path);
}

private static void write(Writer w, CompressedMatrixBlock rmc, int bc, int bl, int bu, int blen)
throws IOException {
final int nrow = rmc.getNumRows();
for(int b = bl; b < bu; b++) {
MatrixIndexes index = new MatrixIndexes(b, bc);
MatrixBlock cb = CLALibSlice.sliceRowsCompressed(rmc, (b - 1) * blen, Math.min(b * blen, nrow) - 1);
CompressedWriteBlock blk = new CompressedWriteBlock(cb);
w.append(index, blk);
}
}

private class WriteTask implements Callable<Object> {
final int id;
final CompressedMatrixBlock rmc;
Expand All @@ -359,13 +365,7 @@ public Object call() throws Exception {
writerLocks[id].lock();
try {
Writer w = writers[id].get();
final int nrow = rmc.getNumRows();
for(int b = bl; b < bu; b++) {
MatrixIndexes index = new MatrixIndexes(b, bc);
MatrixBlock cb = CLALibSlice.sliceRowsCompressed(rmc, (b - 1) * blen, Math.min(b * blen, nrow) - 1);
CompressedWriteBlock blk = new CompressedWriteBlock(cb);
w.append(index, blk);
}
write(w, rmc, bc, bl, bu, blen);
return null;
}
finally {
Expand Down
37 changes: 10 additions & 27 deletions src/main/java/org/apache/sysds/runtime/compress/utils/ACount.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,12 @@ public abstract class ACount<T> {

public abstract void setNext(ACount<T> e);

public void inc() {
count++;
}

public abstract T key();

public abstract ACount<T> get(T key);

public abstract ACount<T> inc(T key, int c, int id);

protected abstract int hashIndex();

public ACount<T> sort() {
Sorter<T> s = new Sorter<T>();
s.sort(this);
Expand All @@ -47,20 +41,17 @@ public ACount<T> sort() {

@Override
public String toString() {
// String s = super.toString();
StringBuilder sb = new StringBuilder();
sb.append(key().toString());
sb.append("=<");
sb.append(id);
sb.append(",");
sb.append(count);
sb.append(">");

if(next() != null) {
sb.append(" -> ");
sb.append(next().toString());
}

return sb.toString();
}

Expand Down Expand Up @@ -95,11 +86,6 @@ public final DblArray key() {
return key;
}

@Override
protected int hashIndex() {
return key.hashCode();
}

@Override
public ACount<DblArray> get(DblArray key) {
DArrCounts e = this;
Expand All @@ -108,7 +94,7 @@ public ACount<DblArray> get(DblArray key) {
e = e.next;
eq = e.key.equals(key);
}
return e;
return eq ? e : null;
}

@Override
Expand All @@ -121,13 +107,12 @@ public DArrCounts inc(DblArray key, int c, int id) {
}

if(eq) {
count += c;
e.count += c;
return e;
}
else { // e.next is null;
e.next = new DArrCounts(key, id, c);
e.next = new DArrCounts(key, id, c);
return e.next;
// return -c;
}
}

Expand Down Expand Up @@ -164,10 +149,10 @@ public final Double key() {
return key;
}

@Override
protected final int hashIndex() {
return hashIndex(key);
}
// @Override
// protected final int hashIndex() {
// return hashIndex(key);
// }

@Override
public DCounts sort() {
Expand All @@ -185,12 +170,12 @@ public ACount<Double> get(Double key) {
@Override
public DCounts inc(Double key, int c, int id) {
DCounts e = this;
while(e.next != null && e.key != key) {
while(e.next != null && key != e.key) {
e = e.next;
}

if(e.key == key) {
count += c;
if(key == e.key) {
e.count += c;
return e;
}
else { // e.next is null;
Expand All @@ -209,8 +194,6 @@ private static class Sorter<T> {
ACount<T> sorted = null;

private void sort(ACount<T> head) {
if(head == null)
return;
ACount<T> current = head;
ACount<T> prev = null;
ACount<T> next = null;
Expand Down
Loading

0 comments on commit 014c8de

Please sign in to comment.