diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java index 8663263e015..af99d407800 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java @@ -24,6 +24,7 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; @@ -64,8 +65,8 @@ else if(rowCols.size() == 1) /** * Encode a full delta representation of the matrix input taking all rows into account. * - * Note the input matrix should not be delta encoded, but instead while processing, enforcing that we do not allocate - * more memory. + * Note the input matrix should not be delta encoded, but instead while processing, enforcing that we do not + * allocate more memory. * * @param m The input matrix that is not delta encoded and should not be modified * @param transposed If the input matrix is transposed. @@ -81,8 +82,8 @@ public static IEncode createFromMatrixBlockDelta(MatrixBlock m, boolean transpos /** * Encode a delta representation of the matrix input taking the first "sampleSize" rows into account. * - * Note the input matrix should not be delta encoded, but instead while processing, enforcing that we do not allocate - * more memory. + * Note the input matrix should not be delta encoded, but instead while processing, enforcing that we do not + * allocate more memory. * * @param m Input matrix that is not delta encoded and should not be modified * @param transposed If the input matrix is transposed. @@ -90,7 +91,8 @@ public static IEncode createFromMatrixBlockDelta(MatrixBlock m, boolean transpos * @param sampleSize The number of rows to consider for the delta encoding (from the beginning) * @return A delta encoded encoding. */ - public static IEncode createFromMatrixBlockDelta(MatrixBlock m, boolean transposed, IColIndex rowCols, int sampleSize) { + public static IEncode createFromMatrixBlockDelta(MatrixBlock m, boolean transposed, IColIndex rowCols, + int sampleSize) { throw new NotImplementedException(); } @@ -117,19 +119,19 @@ else if(m.isInSparseFormat()) return createFromDense(m, rowCol); } - public static IEncode create(ColGroupConst c){ + public static IEncode create(ColGroupConst c) { return new ConstEncoding(-1); } - public static IEncode create(ColGroupEmpty c){ + public static IEncode create(ColGroupEmpty c) { return new EmptyEncoding(); } - public static IEncode create(AMapToData d){ + public static IEncode create(AMapToData d) { return new DenseEncoding(d); } - public static IEncode create(AMapToData d, AOffset i, int nRow){ + public static IEncode create(AMapToData d, AOffset i, int nRow) { return new SparseEncoding(d, i, nRow); } @@ -137,7 +139,7 @@ private static IEncode createFromDenseTransposed(MatrixBlock m, int row) { final DenseBlock db = m.getDenseBlock(); if(!db.isContiguous()) throw new NotImplementedException("Not Implemented non contiguous dense matrix encoding for sample"); - final DoubleCountHashMap map = new DoubleCountHashMap(16); + final DoubleCountHashMap map = new DoubleCountHashMap(); final int off = db.pos(row); final int nCol = m.getNumColumns(); final int end = off + nCol; @@ -145,10 +147,7 @@ private static IEncode createFromDenseTransposed(MatrixBlock m, int row) { // Iteration 1, make Count HashMap. for(int i = off; i < end; i++) // sequential access - if(!Double.isNaN(vals[i])) - map.increment(vals[i]); - else - map.increment(0.0); + map.increment(vals[i]); final int nUnique = map.size(); @@ -163,17 +162,15 @@ else if(map.getOrDefault(0.0, -1) > nCol / 4) { final IntArrayList offsets = new IntArrayList(nV); final AMapToData d = MapToFactory.create(nV, nUnique - 1); - - // for(int i = off, r = 0, di = 0; i < end; i += nCol, r++){ - for(int i = off, r = 0, di = 0; i < end; i++, r++) { + int di = 0; + for(int i = off, r = 0; i < end; i++, r++) { if(vals[i] != 0) { offsets.appendValue(r); - if(!Double.isNaN(vals[i])) - d.set(di++, map.getId(vals[i])); - else - d.set(di++, map.getId(0.0)); + d.set(di++, map.getId(vals[i])); } } + if(di != nV) + throw new RuntimeException("Did not find equal number of elements " + di + " vs " + nV ); final AOffset o = OffsetFactory.createOffset(offsets); return new SparseEncoding(d, o, nCol); @@ -231,7 +228,7 @@ else if(alen - apos > nCol / 4) { // return a dense encoding return new DenseEncoding(d); } else { // return a sparse encoding - // Create output map + // Create output map final AMapToData d = MapToFactory.create(alen - apos, nUnique); // Iteration 2 of non zero values, make either a IEncode Dense or sparse map. @@ -274,13 +271,15 @@ private static IEncode createFromDense(MatrixBlock m, int col) { final IntArrayList offsets = new IntArrayList(nV); final AMapToData d = MapToFactory.create(nV, nUnique - 1); - - for(int i = off, r = 0, di = 0; i < end; i += nCol, r++) { - if(vals[i] != 0 && !Double.isNaN(vals[i])) { + int di = 0; + for(int i = off, r = 0; i < end; i += nCol, r++) { + if(vals[i] != 0) { offsets.appendValue(r); d.set(di++, map.get(vals[i])); } } + if(di != nV) + throw new DMLRuntimeException("Invalid number of zero."); final AOffset o = OffsetFactory.createOffset(offsets); @@ -288,7 +287,7 @@ private static IEncode createFromDense(MatrixBlock m, int col) { } else { // Allocate counts, and iterate once to replace counts with u ids - + final AMapToData d = MapToFactory.create(nRow, nUnique); // Iteration 2, make final map for(int i = off, r = 0; i < end; i += nCol, r++) diff --git a/src/main/java/org/apache/sysds/runtime/compress/utils/ACountHashMap.java b/src/main/java/org/apache/sysds/runtime/compress/utils/ACountHashMap.java index b88de46f18e..0e37b857ab1 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/utils/ACountHashMap.java +++ b/src/main/java/org/apache/sysds/runtime/compress/utils/ACountHashMap.java @@ -123,17 +123,6 @@ public final ACount[] extractValues() { return ret; } - public void replaceWithUIDsNoZero() { - int i = 0; - for(ACount e : data) { - while(e != null) { - if(e.key().equals(0.0)) - e.count = i++; - e = e.next(); - } - } - } - public T getMostFrequent() { T f = null; int fq = 0; diff --git a/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleCountHashMap.java b/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleCountHashMap.java index 2d16e94fd3e..c8970dbce97 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleCountHashMap.java +++ b/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleCountHashMap.java @@ -55,4 +55,19 @@ public double[] getDictionary() { return ret; } + + public void replaceWithUIDsNoZero() { + int i = 0; + Double z = Double.valueOf(0.0); + for(ACount e : data) { + while(e != null) { + if(!e.key().equals(z)) + e.id = i++; + else + e.id = -1; + e = e.next(); + } + } + + } }