Skip to content

Commit

Permalink
improvements (see pr)
Browse files Browse the repository at this point in the history
  • Loading branch information
Baunsgaard committed Aug 16, 2023
1 parent f15f596 commit 793ffb9
Show file tree
Hide file tree
Showing 17 changed files with 237 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ public void set(int n, Integer v) {
*
* @param n index to set.
* @param v the value to set it to.
* @return v as encoded, note this value can be different that the one put in if the map is not able to represent the
* value
* @return v as encoded, note this value can be different than the one put in if the map is not able to represent
* the value
*/
public abstract int setAndGet(int n, int v);

Expand Down Expand Up @@ -327,7 +327,8 @@ protected void preAggregateDenseMultiRowContiguousBy1(double[] mV, int nCol, int
* @param cu The column in m to end at (not inclusive)
* @param indexes The Offset Indexes to iterate through
*/
public final void preAggregateDense(MatrixBlock m, double[] preAV, int rl, int ru, int cl, int cu, AOffset indexes) {
public final void preAggregateDense(MatrixBlock m, double[] preAV, int rl, int ru, int cl, int cu,
AOffset indexes) {
indexes.preAggregateDenseMap(m, preAV, rl, ru, cl, cu, getUnique(), this);
}

Expand Down Expand Up @@ -449,9 +450,9 @@ protected void preAggregateDDC_DDCMultiCol(AMapToData tm, IDictionary td, double

for(int r = h; r < sz; r += 8) {
int r2 = r + 1, r3 = r + 2, r4 = r + 3, r5 = r + 4, r6 = r + 5, r7 = r + 6, r8 = r + 7;
td.addToEntryVectorized(v, tm.getIndex(r), tm.getIndex(r2), tm.getIndex(r3), tm.getIndex(r4), tm.getIndex(r5),
tm.getIndex(r6), tm.getIndex(r7), tm.getIndex(r8), getIndex(r), getIndex(r2), getIndex(r3), getIndex(r4),
getIndex(r5), getIndex(r6), getIndex(r7), getIndex(r8), nCol);
td.addToEntryVectorized(v, tm.getIndex(r), tm.getIndex(r2), tm.getIndex(r3), tm.getIndex(r4),
tm.getIndex(r5), tm.getIndex(r6), tm.getIndex(r7), tm.getIndex(r8), getIndex(r), getIndex(r2),
getIndex(r3), getIndex(r4), getIndex(r5), getIndex(r6), getIndex(r7), getIndex(r8), nCol);
}
}

Expand Down Expand Up @@ -574,8 +575,8 @@ private int preAggregateSDCZ_DDCMultiCol_vect(AMapToData tm, IDictionary td, dou
final int h = size % 8;
int i = 0;
while(i < size - h) {
int t1 = getIndex(i), t2 = getIndex(i + 1), t3 = getIndex(i + 2), t4 = getIndex(i + 3), t5 = getIndex(i + 4),
t6 = getIndex(i + 5), t7 = getIndex(i + 6), t8 = getIndex(i + 7);
int t1 = getIndex(i), t2 = getIndex(i + 1), t3 = getIndex(i + 2), t4 = getIndex(i + 3),
t5 = getIndex(i + 4), t6 = getIndex(i + 5), t7 = getIndex(i + 6), t8 = getIndex(i + 7);

int f1 = it.value(), f2 = it.next(), f3 = it.next(), f4 = it.next(), f5 = it.next(), f6 = it.next(),
f7 = it.next(), f8 = it.next();
Expand Down Expand Up @@ -604,7 +605,8 @@ public final void preAggregateSDCZ_SDCZ(AMapToData tm, IDictionary td, AOffset t
preAggregateSDCZ_SDCZMultiCol(tm, td, tof, of, ret.getValues(), nCol);
}

private final void preAggregateSDCZ_SDCZSingleCol(AMapToData tm, double[] td, AOffset tof, AOffset of, double[] dv) {
private final void preAggregateSDCZ_SDCZSingleCol(AMapToData tm, double[] td, AOffset tof, AOffset of,
double[] dv) {
final AOffsetIterator itThat = tof.getOffsetIterator();
final AOffsetIterator itThis = of.getOffsetIterator();
final int tSize = tm.size() - 1, size = size() - 1;
Expand Down Expand Up @@ -767,7 +769,7 @@ public void copy(AMapToData d) {
if(d.nUnique == 1)
return;
// else if(d instanceof MapToBit)
// copyBit((MapToBit) d);
// copyBit((MapToBit) d);
else if(d instanceof MapToInt)
copyInt((MapToInt) d);
else {
Expand All @@ -782,7 +784,7 @@ protected void copyInt(MapToInt d) {
}

// protected void copyBit(MapToBit d) {
// copyBitLong(d.getData());
// copyBitLong(d.getData());
// }

public abstract void copyInt(int[] d);
Expand All @@ -800,6 +802,13 @@ public int getMax() {
return m;
}

/**
 * Get the maximum number of distinct values that this map encoding is able to represent. For instance, a bit
 * encoding can represent 2 distinct values.
 *
 * @return The maximum number of distinct values that can be encoded
 */
public abstract int getMaxPossible();

public abstract AMapToData resize(int unique);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,10 @@ private static int longSize(int size) {
return Math.max(size >> 6, 0) + 1;
}

/**
 * Get the maximum number of distinct values this map can encode.
 *
 * @return 2, the number of distinct values reported for this encoding
 */
@Override
public int getMaxPossible() {
	return 2;
}

@Override
public String toString() {
return super.toString() + _size + "[" + _data.length + "]";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,11 @@ public AMapToData appendN(IMapToDataGroup[] d) {
return new MapToByte(getUnique(), ret);
}

/**
 * Get the maximum number of distinct values this map can encode.
 *
 * @return 256, i.e. two to the power of the number of bits in a byte
 */
@Override
public int getMaxPossible() {
	return 1 << Byte.SIZE;
}

@Override
public boolean equals(AMapToData e) {
return e instanceof MapToByte && //
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,11 @@ public AMapToData appendN(IMapToDataGroup[] d) {
return new MapToChar(getUnique(), ret);
}

/**
 * Get the maximum number of distinct values this map can encode.
 *
 * A Java char covers the range 0 to Character.MAX_VALUE inclusive, hence the number of distinct values is one larger
 * than Character.MAX_VALUE. This matches the sibling encodings that report a count of values (e.g., 256 for a byte)
 * rather than the largest value.
 *
 * @return The number of distinct char values, 65536
 */
@Override
public int getMaxPossible() {
	return Character.MAX_VALUE + 1;
}

@Override
public boolean equals(AMapToData e) {
return e instanceof MapToChar && //
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,15 @@ public AMapToData appendN(IMapToDataGroup[] d) {
throw new NotImplementedException();
}

/**
 * Get the maximum number of distinct values reported for this map.
 *
 * NOTE(review): the returned value is Character.MAX_VALUE * 256 (65535 * 256). If the char part of this encoding is
 * allowed to use all 65536 char values the real capacity would be 65536 * 256 - confirm against the storage layout.
 *
 * @return Character.MAX_VALUE multiplied by 256
 */
@Override
public int getMaxPossible() {
	final int charLimit = Character.MAX_VALUE;
	return charLimit * 256;
}

@Override
public boolean equals(AMapToData e) {
return e instanceof MapToCharPByte && //
e.getUnique() == getUnique() &&//
e.getUnique() == getUnique() && //
Arrays.equals(((MapToCharPByte) e)._data_b, _data_b) && //
Arrays.equals(((MapToCharPByte) e)._data_c, _data_c);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,11 @@ public AMapToData appendN(IMapToDataGroup[] d) {
return new MapToInt(getUnique(), ret);
}

/**
 * Get the maximum number of distinct values reported for this map.
 *
 * @return Integer.MAX_VALUE, the limit reported for the int backed map
 */
@Override
public int getMaxPossible(){
return Integer.MAX_VALUE;
}

@Override
public boolean equals(AMapToData e) {
return e instanceof MapToInt && //
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ public AMapToData appendN(IMapToDataGroup[] d) {
return new MapToZero(p);
}

/**
 * Get the maximum number of distinct values reported for this map.
 *
 * @return 1, this map reports room for a single distinct value
 */
@Override
public int getMaxPossible(){
return 1;
}

@Override
public boolean equals(AMapToData e) {
return e instanceof MapToZero && //
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.Pair;

public abstract class ACLAScheme implements ICLAScheme {
protected final IColIndex cols;
Expand All @@ -44,7 +45,6 @@ public ICLAScheme update(MatrixBlock data) {
return update(data, getColIndices());
}


@Override
public ICLAScheme updateSparse(MatrixBlock data) {
// fallback to default
Expand All @@ -63,7 +63,6 @@ public ICLAScheme updateGeneric(MatrixBlock data) {
return updateGeneric(data, getColIndices());
}


@Override
public ICLAScheme updateSparse(MatrixBlock data, IColIndex columns) {
// fallback to default
Expand All @@ -82,16 +81,39 @@ public ICLAScheme updateGeneric(MatrixBlock data, IColIndex columns) {
return update(data, columns);
}

/**
 * Update the scheme with the given data and encode the data, using the column indexes returned by
 * getColIndices().
 *
 * @param data The data to update the scheme with and to encode
 * @return A pair of the scheme after the update and the encoded column group
 */
@Override
public Pair<ICLAScheme, AColGroup> updateAndEncode(MatrixBlock data) {
	final IColIndex ownColumns = getColIndices();
	return updateAndEncode(data, ownColumns);
}

/**
 * Update the scheme with the given data and encode the given columns of the data.
 *
 * The work is delegated to tryUpdateAndEncode. No fallback to separate update and encode calls is attempted here,
 * therefore any exception thrown by tryUpdateAndEncode propagates to the caller.
 *
 * @param data    The data to update the scheme with and to encode
 * @param columns The columns of the data to use
 * @return A pair of the scheme after the update and the encoded column group
 */
@Override
public Pair<ICLAScheme, AColGroup> updateAndEncode(MatrixBlock data, IColIndex columns) {
	final Pair<ICLAScheme, AColGroup> updatedAndEncoded = tryUpdateAndEncode(data, columns);
	return updatedAndEncoded;
}

/**
 * Default implementation of the combined update and encode, performing the two steps one after the other.
 *
 * The encoding is done by the scheme returned from the update call, not by this instance: update returns an
 * ICLAScheme and nothing guarantees that it is the same object as this scheme, so encoding with this instance
 * could use a scheme that does not reflect the update.
 *
 * @param data    The data to update the scheme with and to encode
 * @param columns The columns of the data to use
 * @return A pair of the updated scheme and the column group encoded by that updated scheme
 */
protected Pair<ICLAScheme, AColGroup> tryUpdateAndEncode(MatrixBlock data, IColIndex columns) {
	final ICLAScheme s = update(data, columns);
	// encode with the updated scheme, it may be a different instance than this.
	final AColGroup g = s.encode(data, columns);
	return new Pair<>(s, g);
}

protected final void validate(MatrixBlock data, IColIndex columns) throws IllegalArgumentException {
if(columns.size() != cols.size())
throw new IllegalArgumentException(
"Invalid number of columns to encode expected: " + cols.size() + " but got: " + columns.size());

final int nCol = data.getNumColumns();
if(nCol < cols.get(cols.size() - 1))
throw new IllegalArgumentException("Invalid columns to encode with max col:" + nCol+ " list of columns: "+ columns);
throw new IllegalArgumentException(
"Invalid columns to encode with max col:" + nCol + " list of columns: " + columns);
}



}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.Pair;
import org.apache.sysds.runtime.util.CommonThreadPool;

/**
Expand Down Expand Up @@ -280,14 +281,11 @@ protected UpdateAndEncodeTask(int i, int e, AColGroup[] ret, MatrixBlock mb) {

@Override
public Object call() throws Exception {
final boolean dense = mb.getDenseBlock().isContiguous();

for(int j = i; j < e; j++) {
if(dense)
encodings[j] = encodings[j].updateDense(mb);
else
encodings[j] = encodings[j].update(mb);
ret[j] = encodings[j].encode(mb);
Pair<ICLAScheme, AColGroup> p = encodings[j].updateAndEncode(mb);
encodings[j] = p.getKey();
ret[j] =p.getValue();
}
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.apache.sysds.runtime.compress.colgroup.scheme;

import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
Expand All @@ -31,6 +32,7 @@
import org.apache.sysds.runtime.compress.utils.DblArray;
import org.apache.sysds.runtime.compress.utils.DblArrayCountHashMap;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.Pair;

public class DDCSchemeMC extends DDCScheme {

Expand All @@ -53,6 +55,7 @@ protected DDCSchemeMC(ColGroupDDC g) {
DblArray d = null;
while((d = r.nextRow()) != null)
map.increment(d);

emptyRow = new DblArray(new double[dictCols]);
}

Expand Down Expand Up @@ -87,7 +90,7 @@ public ICLAScheme update(MatrixBlock data, IColIndex columns) {
}

if(r < nRow)
map.increment(emptyRow, nRow - r - 1);
map.increment(emptyRow, nRow - r - 1);

return this;
}
Expand All @@ -103,15 +106,15 @@ public AColGroup encode(MatrixBlock data, IColIndex columns) {

DblArray cellVals;
ACount<DblArray> emptyIdx = map.getC(emptyRow);
if(emptyIdx == null){
if(emptyIdx == null) {

while((cellVals = reader.nextRow()) != null) {
final int row = reader.getCurrentRowIndex();
final int id = map.getId(cellVals);
d.set(row, id);
}
}
else{
else {
int r = 0;
while((cellVals = reader.nextRow()) != null) {
final int row = reader.getCurrentRowIndex();
Expand All @@ -131,4 +134,54 @@ public AColGroup encode(MatrixBlock data, IColIndex columns) {
return ColGroupDDC.create(columns, lastDict, d, null);
}

/**
 * Update the map of this scheme with the tuples of the given data and encode the data in a single pass over the
 * rows returned by the reader.
 *
 * Row indexes that the reader does not return are treated as empty rows: they are counted in the map via the empty
 * tuple and assigned the id of the empty tuple. This is done no matter if the empty tuple was already part of the
 * map before the call, so that skipped rows are never left at the default value of the mapping.
 *
 * NOTE(review): the mapping is allocated based on the number of distinct tuples known before this call. Ids added
 * during the pass are only validated against getMaxPossible() of that mapping - confirm that the unique count of
 * the mapping does not need to be updated before creating the column group.
 *
 * @param data    The data to update the scheme with and to encode
 * @param columns The columns of the data to use
 * @return A pair of this scheme and the encoded DDC column group
 * @throws DMLCompressionException If an id is assigned that the allocated mapping is not able to represent. Note
 *                                 that the map of this scheme can already have been modified at that point.
 */
@Override
protected Pair<ICLAScheme, AColGroup> tryUpdateAndEncode(MatrixBlock data, IColIndex columns) {
	validate(data, columns);
	final int nRow = data.getNumRows();
	final ReaderColumnSelection reader = ReaderColumnSelection.createReader(//
		data, columns, false, 0, nRow);
	final AMapToData d = MapToFactory.create(nRow, map.size());
	final int max = d.getMaxPossible();

	// id of the empty tuple, or -1 as long as the empty tuple is not contained in the map.
	final ACount<DblArray> emptyIdx = map.getC(emptyRow);
	int emptyId = emptyIdx == null ? -1 : emptyIdx.id;

	DblArray cellVals;
	int r = 0; // next row index that has not been assigned an id yet
	while((cellVals = reader.nextRow()) != null) {
		final int row = reader.getCurrentRowIndex();
		if(row != r) {
			// the reader skipped the rows [r, row), encode them as empty.
			emptyId = registerEmptyRows(emptyId, row - r, max);
			while(r < row)
				d.set(r++, emptyId);
		}
		final int id = map.increment(cellVals);
		checkEncodable(id, max);
		d.set(row, id);
		r++;
	}

	if(r < nRow) {
		// trailing rows not returned by the reader.
		emptyId = registerEmptyRows(emptyId, nRow - r, max);
		while(r < nRow)
			d.set(r++, emptyId);
	}

	if(lastDict == null || lastDict.getNumberOfValues(columns.size()) != map.size())
		lastDict = DictionaryFactory.create(map, columns.size(), false, data.getSparsity());

	final AColGroup g = ColGroupDDC.create(columns, lastDict, d, null);
	final ICLAScheme s = this;
	return new Pair<>(s, g);
}

/**
 * Count a number of empty rows in the map and return the id of the empty tuple.
 *
 * @param emptyId The id of the empty tuple if already known, otherwise a negative value
 * @param count   The number of empty rows to add to the count of the empty tuple
 * @param max     The maximum number of distinct values the mapping is able to represent
 * @return The id of the empty tuple
 */
private int registerEmptyRows(int emptyId, int count, int max) {
	map.increment(emptyRow, count);
	if(emptyId >= 0)
		return emptyId;
	// first time the empty tuple is seen, look up the id it was assigned.
	final int id = map.getC(emptyRow).id;
	checkEncodable(id, max);
	return id;
}

/**
 * Verify that the given id is representable in a mapping with the given number of possible values.
 *
 * @param id  The id to verify
 * @param max The maximum number of distinct values the mapping is able to represent
 */
private static void checkEncodable(int id, int max) {
	if(id >= max)
		throw new DMLCompressionException("Failed update and encode with " + max + " possible values");
}
}
Loading

0 comments on commit 793ffb9

Please sign in to comment.