Skip to content

Commit

Permalink
compress with NaNs
Browse files Browse the repository at this point in the history
  • Loading branch information
Baunsgaard committed Aug 19, 2023
1 parent 5fd8322 commit 73410bf
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 131 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,6 @@ protected AColGroup sliceSingleColumn(int idx) {

@Override
public boolean containsValue(double pattern) {

if(Double.isNaN(pattern) || Double.isInfinite(pattern))
return ColGroupUtils.containsInfOrNan(pattern, _reference) || _dict.containsValue(pattern);
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import org.apache.sysds.runtime.compress.colgroup.functional.LinearRegression;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.colgroup.indexes.IIterate;
import org.apache.sysds.runtime.compress.colgroup.insertionsort.AInsertionSorter;
import org.apache.sysds.runtime.compress.colgroup.insertionsort.InsertionSorterFactory;
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
Expand Down Expand Up @@ -226,12 +225,12 @@ private void logEstVsActual(double time, AColGroup act, CompressedSizeInfoColGro
if(estC < actC * 0.75) {
String warning = "The estimate cost is significantly off : " + est;
LOG.debug(
String.format("time[ms]: %10.2f %25s est %10.0f -- act %10.0f distinct:%5d cols:%s wanted:%s\n\t\t%s", time,
retType, estC, actC, act.getNumValues(), cols, wanted, warning));
String.format("time[ms]: %10.2f %25s est %10.0f -- act %10.0f distinct:%5d cols:%s wanted:%s\n\t\t%s",
time, retType, estC, actC, act.getNumValues(), cols, wanted, warning));
}
else {
LOG.debug(String.format("time[ms]: %10.2f %25s est %10.0f -- act %10.0f distinct:%5d cols:%s wanted:%s", time,
retType, estC, actC, act.getNumValues(), cols, wanted));
LOG.debug(String.format("time[ms]: %10.2f %25s est %10.0f -- act %10.0f distinct:%5d cols:%s wanted:%s",
time, retType, estC, actC, act.getNumValues(), cols, wanted));
}

}
Expand All @@ -258,7 +257,7 @@ private AColGroup compress(CompressedSizeInfoColGroup cg) throws Exception {
final boolean t = cs.transposed;

// Fast path compressions
if(ct == CompressionType.EMPTY && (!t || isAllNanTransposed(cg)))
if(ct == CompressionType.EMPTY && !t)
return new ColGroupEmpty(colIndexes);
else if(ct == CompressionType.UNCOMPRESSED) // don't construct mapping if uncompressed
return ColGroupUncompressed.create(colIndexes, in, t);
Expand Down Expand Up @@ -321,7 +320,6 @@ private AColGroup directCompressDDCSingleCol(IColIndex colIndexes, CompressedSiz
final AMapToData d = MapToFactory.create(nRow, Math.max(Math.min(cg.getNumOffs() + 1, nRow), 126));
final DoubleCountHashMap map = new DoubleCountHashMap(cg.getNumVals());


// Unlike the multi-column case, no special handling of zero entries is needed.
if(cs.transposed)
readToMapDDCTransposed(col, map, d);
Expand Down Expand Up @@ -418,7 +416,7 @@ else if(in.getDenseBlock().isContiguous()) {
}
else {
final DenseBlock db = in.getDenseBlock();
for(int r = 0; r < nRow; r++){
for(int r = 0; r < nRow; r++) {
final double[] dv = db.values(r);
int off = db.pos(r) + col;
data.set(r, map.increment(dv[off]));
Expand Down Expand Up @@ -657,10 +655,7 @@ private AColGroup compressSingleColSDCFromSparseTransposedBlock(IColIndex cols,

// count distinct items frequencies
for(int j = apos; j < alen; j++)
if(!Double.isNaN(vals[j]))
map.increment(vals[j]);
else
map.increment(0.0);
map.increment(vals[j]);

ACount<Double>[] entries = map.extractValues();
Arrays.sort(entries, Comparator.comparing(x -> -x.count));
Expand All @@ -682,10 +677,7 @@ private AColGroup compressSingleColSDCFromSparseTransposedBlock(IColIndex cols,
else {
final AMapToData mapToData = MapToFactory.create((alen - apos), entries.length);
for(int j = apos; j < alen; j++)
if(!Double.isNaN(vals[j]))
mapToData.set(j - apos, map.get(vals[j]));
else
mapToData.set(j - apos, map.get(0.0));
mapToData.set(j - apos, map.get(vals[j]));
return ColGroupSDCZeros.create(cols, nRow, Dictionary.create(dict), offsets, mapToData, counts);
}
}
Expand Down Expand Up @@ -721,58 +713,6 @@ else if(entries.length == 1) {
}
}

/**
 * Check if all values of the given column group are NaN, dispatching on the storage format of the input.
 *
 * @param cg The column group information whose columns should be inspected
 * @return The result of the sparse check if the input is in sparse format, otherwise of the dense check
 */
private boolean isAllNanTransposed(CompressedSizeInfoColGroup cg) {
	final IColIndex columns = cg.getColumns();
	if(in.isInSparseFormat())
		return isAllNanTransposedSparse(columns);
	return isAllNanTransposedDense(columns);
}

/**
 * Check if every stored value of the sparse input, in the rows addressed by the given indexes, is NaN.
 *
 * Each index in cols is used as a row index into the sparse block. Empty rows are skipped and therefore never
 * make the check fail.
 *
 * @param cols The indexes to inspect
 * @return false as soon as a stored non-NaN value is found, true otherwise
 */
private boolean isAllNanTransposedSparse(IColIndex cols) {
	final SparseBlock sb = in.getSparseBlock();
	final IIterate it = cols.iterator();
	while(it.hasNext()) {
		final int c = it.next();
		if(sb.isEmpty(c))
			continue;
		// The array returned by values(c) can be shared between rows or over-allocated,
		// only the range [pos(c), pos(c) + size(c)) holds the entries of row c.
		final int apos = sb.pos(c);
		final int alen = apos + sb.size(c);
		final double[] vals = sb.values(c);
		for(int j = apos; j < alen; j++) {
			if(!Double.isNaN(vals[j]))
				return false;
		}
	}
	return true;
}

/**
 * Check if every cell of the dense input, in the rows addressed by the given indexes, is NaN.
 *
 * For a contiguous dense block the single backing array is scanned with a row offset of index * nRow. Otherwise
 * the backing array and offset of each row are looked up through the dense block.
 *
 * @param cols The indexes to inspect
 * @return false as soon as a non-NaN value is found, true otherwise
 */
private boolean isAllNanTransposedDense(IColIndex cols) {
	final boolean contiguous = in.getDenseBlock().isContiguous();
	if(!contiguous) {
		final DenseBlock block = in.getDenseBlock();
		for(IIterate iter = cols.iterator(); iter.hasNext();) {
			final int idx = iter.next();
			final double[] rowValues = block.values(idx);
			final int start = block.pos(idx);
			for(int i = 0; i < nRow; i++) {
				if(!Double.isNaN(rowValues[start + i]))
					return false;
			}
		}
		return true;
	}

	final double[] allValues = in.getDenseBlockValues();
	for(IIterate iter = cols.iterator(); iter.hasNext();) {
		final int idx = iter.next();
		final int start = idx * nRow;
		for(int i = 0; i < nRow; i++) {
			if(!Double.isNaN(allValues[start + i]))
				return false;
		}
	}
	return true;
}

private class CompressTask implements Callable<Object> {

private final List<CompressedSizeInfoColGroup> _groups;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ else if(map.getOrDefault(0.0, -1) > nCol / 4) {
}
}
if(di != nV)
throw new RuntimeException("Did not find equal number of elements " + di + " vs " + nV );
throw new RuntimeException("Did not find equal number of elements " + di + " vs " + nV);

final AOffset o = OffsetFactory.createOffset(offsets);
return new SparseEncoding(d, o, nCol);
Expand All @@ -180,12 +180,8 @@ else if(map.getOrDefault(0.0, -1) > nCol / 4) {
final AMapToData d = MapToFactory.create(nCol, nUnique);

// Iteration 2, make final map
for(int i = off, r = 0; i < end; i++, r++) {
if(!Double.isNaN(vals[i]))
d.set(r, map.getId(vals[i]));
else
d.set(r, map.getId(0.0));
}
for(int i = off, r = 0; i < end; i++, r++)
d.set(r, map.getId(vals[i]));

return new DenseEncoding(d);
}
Expand All @@ -203,16 +199,13 @@ private static IEncode createFromSparseTransposed(MatrixBlock m, int row) {

// Iteration 1 of non zero values, make Count HashMap.
for(int i = apos; i < alen; i++) {
// sequential of non zero cells.
if(!Double.isNaN(avals[i]))
map.increment(avals[i]);

map.increment(avals[i]);
}

final int nUnique = map.size();

final int nCol = m.getNumColumns();
if(nUnique == 0) // only if all NaN
if(nUnique == 0)
return new EmptyEncoding();
else if(alen - apos > nCol / 4) { // return a dense encoding
// If the row was full but the overall matrix is sparse.
Expand All @@ -221,8 +214,7 @@ else if(alen - apos > nCol / 4) { // return a dense encoding
// Since the dictionary is allocated with zero then we exploit that here and
// only iterate through non zero entries.
for(int i = apos; i < alen; i++)
if(!Double.isNaN(avals[i])) // correction one to assign unique IDs taking into account zero
d.set(aix[i], map.getId(avals[i]) + correct);
d.set(aix[i], map.getId(avals[i]) + correct);

// the rest is automatically set to zero.
return new DenseEncoding(d);
Expand All @@ -233,8 +225,7 @@ else if(alen - apos > nCol / 4) { // return a dense encoding

// Iteration 2 of non zero values, make either a IEncode Dense or sparse map.
for(int i = apos, j = 0; i < alen; i++, j++)
if(!Double.isNaN(avals[i]))
d.set(j, map.getId(avals[i]));
d.set(j, map.getId(avals[i]));

// Iteration 3 of non zero indexes, make a Offset Encoding to know what cells are zero and not.
// not done yet
Expand All @@ -256,10 +247,7 @@ private static IEncode createFromDense(MatrixBlock m, int col) {

// Iteration 1, make Count HashMap.
for(int i = off; i < end; i += nCol) // jump down through rows.
if(!Double.isNaN(vals[i]))
map.increment(vals[i]);
else
map.increment(0.0);
map.increment(vals[i]);
final int nUnique = map.size();
if(nUnique == 1)
return new ConstEncoding(m.getNumColumns());
Expand Down Expand Up @@ -291,10 +279,7 @@ private static IEncode createFromDense(MatrixBlock m, int col) {
final AMapToData d = MapToFactory.create(nRow, nUnique);
// Iteration 2, make final map
for(int i = off, r = 0; i < end; i += nCol, r++)
if(!Double.isNaN(vals[i]))
d.set(r, map.getId(vals[i]));
else
d.set(r, map.getId(0.0));
d.set(r, map.getId(vals[i]));
return new DenseEncoding(d);
}
}
Expand All @@ -316,11 +301,8 @@ private static IEncode createFromSparse(MatrixBlock m, int col) {
final int[] aix = sb.indexes(r);
final int index = Arrays.binarySearch(aix, apos, alen, col);
if(index >= 0) {
final double v = sb.values(r)[index];
if(!Double.isNaN(v)) {
offsets.appendValue(r);
map.increment(sb.values(r)[index]);
}
offsets.appendValue(r);
map.increment(sb.values(r)[index]);
}
}
if(offsets.size() == 0)
Expand All @@ -341,7 +323,7 @@ private static IEncode createFromSparse(MatrixBlock m, int col) {
final int index = Arrays.binarySearch(aix, apos, alen, col);
if(index >= 0) {
final double v = sb.values(r)[index];
if(index >= 0 && !Double.isNaN(v))
if(index >= 0)
d.set(off++, map.getId(v));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,8 @@

/** Base class for all column selection readers. */
public abstract class ReaderColumnSelection {

protected static final Log LOG = LogFactory.getLog(ReaderColumnSelection.class.getName());

protected static boolean nanEncountered = false;

protected final IColIndex _colIndexes;
protected final DblArray reusableReturn;
protected final double[] reusableArr;
Expand Down Expand Up @@ -114,11 +111,4 @@ private static void checkInput(final MatrixBlock rawBlock, final IColIndex colIn
else if(rl >= ru)
throw new DMLCompressionException("Invalid inverse range for reader " + rl + " to " + ru);
}

/**
 * Log a warning the first time a NaN value is encountered.
 *
 * The static nanEncountered flag is set after the warning is logged, so subsequent calls return without logging.
 */
protected void warnNaN() {
	if(nanEncountered)
		return;
	LOG.warn("NaN value encountered, replaced by 0 in compression, since nan is not supported");
	nanEncountered = true;
}
}
25 changes: 17 additions & 8 deletions src/main/java/org/apache/sysds/runtime/compress/utils/ACount.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,15 @@

package org.apache.sysds.runtime.compress.utils;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public abstract class ACount<T> {
protected static final Log LOG = LogFactory.getLog(ACount.class.getName());

/** The current count of this element */
public int count;
/** The current ID of this element should be unique for the user. */
public int id;

public abstract ACount<T> next();
Expand Down Expand Up @@ -150,11 +157,6 @@ public final Double key() {
return key;
}

// @Override
// protected final int hashIndex() {
// return hashIndex(key);
// }

@Override
public DCounts sort() {
return (DCounts) super.sort();
Expand All @@ -163,19 +165,19 @@ public DCounts sort() {
@Override
public ACount<Double> get(Double key) {
DCounts e = this;
while(e != null && e.key != key)
while(e != null && !eq(key, e.key))
e = e.next;
return e;
}

@Override
public DCounts inc(Double key, int c, int id) {
DCounts e = this;
while(e.next != null && key != e.key) {
while(e.next != null && !eq(key, e.key)) {
e = e.next;
}

if(key == e.key) {
if(eq(key, e.key)) {
e.count += c;
return e;
}
Expand All @@ -185,6 +187,13 @@ public DCounts inc(Double key, int c, int id) {
}
}

/**
 * Bitwise equality of two double values.
 *
 * The raw IEEE 754 bit patterns are compared, so a NaN is equal to a NaN with the same bit pattern (which the
 * == operator never reports), while 0.0 and -0.0 are treated as different values.
 *
 * This is called for every element visited while walking a bucket chain, so it must not log or allocate.
 *
 * @param a The first value
 * @param b The second value
 * @return true if both values have identical raw bit patterns
 */
private static boolean eq(double a, double b) {
	return Double.doubleToRawLongBits(a) == Double.doubleToRawLongBits(b);
}

public static final int hashIndex(double key) {
final long bits = Double.doubleToLongBits(key);
return Math.abs((int) (bits ^ (bits >>> 32)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ public int getId(T key) {
public ACount<T> getC(T key) {
int ix = data.length < shortCutSize ? 0 : hash(key) % data.length;
ACount<T> l = data[ix];
return (l != null) ? l.get(key) : null;
return l != null ? l.get(key) : null;
}

public int getOrDefault(T key, int def) {
Expand Down
Loading

0 comments on commit 73410bf

Please sign in to comment.