From 8dbfc235cc222b5e7b3fdf9adfc1a09b95c84249 Mon Sep 17 00:00:00 2001 From: baunsgaard Date: Mon, 7 Aug 2023 13:37:40 +0200 Subject: [PATCH] [SYSTEMDS-3608] Cocode shortcut This commit adds a shortcut in que based cocode algorithm that indicate if no cocoding have happened the last 5 tries in the que abort cocoding. Previously cases where cocoding was on the border then all pairs of columns would be tried single threaded. To avoid this the implementation is now avoiding the all pairing if no cocoding is happening. The cost optimization of cocoding have been tuned to add a minor delta on the cost based on the average of the column indexes in the groups. This makes the que based optimizer likely to cocode columns next to each other if they have similar cost, making the cocoding algorithm faster in cases with high corelation between columns next to each other. This commit also adds an TwoRangesIndex, that is useful in the case where almost all columns get cocoded, and the greedy algorithm combine colgroups with thousands of groups. In this case the column arrays combining dominated. Also added in this commit is compression schemes for SDC and refinements for SDC, and a natural update and upgrade scheme progression from initially empty columns to const to DDC. Next up is to add a transition to SDC in case the distribution of values become dominated by specific values. We also add a minor update to AggTernaryOp for compressed with a shortcut that avoids an decompression in L2SVM. Closes #1874 --- .../compress/CompressedMatrixBlock.java | 40 +-- .../CompressedMatrixBlockFactory.java | 6 +- .../compress/cocode/CoCodePriorityQue.java | 47 ++- .../compress/cocode/CoCoderFactory.java | 87 ++++-- .../runtime/compress/cocode/ColIndexes.java | 6 +- .../sysds/runtime/compress/colgroup/ASDC.java | 7 + .../runtime/compress/colgroup/ASDCZero.java | 7 + .../compress/colgroup/ColGroupDDCFOR.java | 4 +- .../compress/colgroup/ColGroupEmpty.java | 3 +- .../compress/colgroup/ColGroupOLE.java | 2 +- .../compress/colgroup/ColGroupRLE.java | 3 +- .../compress/colgroup/ColGroupSDC.java | 6 - .../compress/colgroup/ColGroupSDCFOR.java | 12 +- .../compress/colgroup/ColGroupSDCSingle.java | 6 - .../colgroup/ColGroupSDCSingleZeros.java | 6 - .../compress/colgroup/ColGroupSDCZeros.java | 6 - .../colgroup/ColGroupUncompressed.java | 3 +- .../compress/colgroup/indexes/AColIndex.java | 3 +- .../compress/colgroup/indexes/ArrayIndex.java | 17 + .../colgroup/indexes/ColIndexFactory.java | 29 +- .../compress/colgroup/indexes/IColIndex.java | 34 +- .../compress/colgroup/indexes/RangeIndex.java | 71 +++-- .../colgroup/indexes/SingleIndex.java | 5 + .../compress/colgroup/indexes/TwoIndex.java | 13 +- .../colgroup/indexes/TwoRangesIndex.java | 269 ++++++++++++++++ .../compress/colgroup/mapping/MapToByte.java | 6 +- .../colgroup/scheme/CompressionScheme.java | 278 +++++++++++++++++ .../compress/colgroup/scheme/ConstScheme.java | 34 +- .../compress/colgroup/scheme/DDCScheme.java | 2 +- .../compress/colgroup/scheme/DDCSchemeSC.java | 8 +- .../compress/colgroup/scheme/EmptyScheme.java | 37 ++- .../compress/colgroup/scheme/RLEScheme.java | 66 ++++ .../compress/colgroup/scheme/SDCScheme.java | 85 +++++ .../compress/colgroup/scheme/SDCSchemeMC.java | 216 +++++++++++++ .../compress/colgroup/scheme/SDCSchemeSC.java | 218 +++++++++++++ .../colgroup/scheme/SchemeFactory.java | 6 +- .../runtime/compress/estim/ComEstExact.java | 3 +- .../runtime/compress/estim/ComEstSample.java | 3 +- .../estim/CompressedSizeInfoColGroup.java | 27 +- .../estim/encoding/DenseEncoding.java | 2 +- .../compress/lib/CLALibAggTernaryOp.java | 141 +++++++++ .../compress/lib/CLALibCombineGroups.java | 12 +- .../runtime/compress/lib/CLALibScheme.java | 48 +++ .../runtime/compress/lib/CLALibSlice.java | 11 + .../runtime/compress/lib/CLALibStack.java | 7 +- .../runtime/compress/lib/CLALibUtils.java | 2 + .../compress/utils/DoubleCountHashMap.java | 38 +-- .../runtime/compress/utils/IntArrayList.java | 28 +- .../cp/AggregateTernaryCPInstruction.java | 2 +- .../spark/AggregateTernarySPInstruction.java | 4 +- .../runtime/matrix/data/LibMatrixAgg.java | 13 +- .../runtime/matrix/data/MatrixBlock.java | 58 ++-- .../compress/CompressedCustomTests.java | 4 +- .../compress/CompressedLoggingTests.java | 35 ++- .../compress/CompressedMatrixTest.java | 10 +- .../compress/CompressedTestBase.java | 50 ++- .../compress/colgroup/ColGroupTest.java | 30 +- .../compress/indexes/CustomIndexTest.java | 229 +++++++++++++- .../compress/indexes/IndexesTest.java | 294 +++++++++++++++++- .../compress/indexes/NegativeIndexTest.java | 37 ++- .../compress/lib/CombineGroupsTest.java | 4 +- .../sysds/test/component/misc/ThreadPool.java | 17 +- 62 files changed, 2457 insertions(+), 300 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/TwoRangesIndex.java create mode 100644 src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/CompressionScheme.java create mode 100644 src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/RLEScheme.java create mode 100644 src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/SDCScheme.java create mode 100644 src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/SDCSchemeMC.java create mode 100644 src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/SDCSchemeSC.java create mode 100644 src/main/java/org/apache/sysds/runtime/compress/lib/CLALibAggTernaryOp.java create mode 100644 src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScheme.java diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java index 4a1d4928568..85313234b12 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java @@ -64,6 +64,7 @@ import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseRow; +import org.apache.sysds.runtime.functionobjects.SwapIndex; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; import org.apache.sysds.runtime.instructions.cp.ScalarObject; @@ -77,7 +78,6 @@ import org.apache.sysds.runtime.matrix.data.RandomMatrixGenerator; import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator; import org.apache.sysds.runtime.matrix.operators.AggregateOperator; -import org.apache.sysds.runtime.matrix.operators.AggregateTernaryOperator; import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import org.apache.sysds.runtime.matrix.operators.CMOperator; @@ -587,11 +587,21 @@ else if(isOverlapping()) { @Override public MatrixBlock reorgOperations(ReorgOperator op, MatrixValue ret, int startRow, int startColumn, int length) { - // Allow transpose to be compressed output. In general we need to have a transposed flag on - // the compressed matrix. https://issues.apache.org/jira/browse/SYSTEMDS-3025 - printDecompressWarning(op.getClass().getSimpleName() + " -- " + op.fn.getClass().getSimpleName()); - MatrixBlock tmp = decompress(op.getNumThreads()); - return tmp.reorgOperations(op, ret, startRow, startColumn, length); + if(op.fn instanceof SwapIndex && this.getNumColumns() == 1) { + MatrixBlock tmp = decompress(op.getNumThreads()); + long nz = tmp.setNonZeros(tmp.getNonZeros()); + tmp = new MatrixBlock(tmp.getNumColumns(), tmp.getNumRows(), tmp.getDenseBlockValues()); + tmp.setNonZeros(nz); + return tmp; + } + else { + // Allow transpose to be compressed output. In general we need to have a transposed flag on + // the compressed matrix. https://issues.apache.org/jira/browse/SYSTEMDS-3025 + String message = op.getClass().getSimpleName() + " -- " + op.fn.getClass().getSimpleName(); + MatrixBlock tmp = getUncompressed(message, op.getNumThreads()); + return tmp.reorgOperations(op, ret, startRow, startColumn, length); + } + } public boolean isOverlapping() { @@ -788,24 +798,6 @@ public MatrixBlock sortOperations(MatrixValue weights, MatrixBlock result) { return getUncompressed("sortOperations").sortOperations(right, result); } - @Override - public MatrixBlock aggregateTernaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock m3, MatrixBlock ret, - AggregateTernaryOperator op, boolean inCP) { - boolean m1C = m1 instanceof CompressedMatrixBlock; - boolean m2C = m2 instanceof CompressedMatrixBlock; - boolean m3C = m3 instanceof CompressedMatrixBlock; - printDecompressWarning("aggregateTernaryOperations " + op.aggOp.getClass().getSimpleName() + " " - + op.indexFn.getClass().getSimpleName() + " " + op.aggOp.increOp.fn.getClass().getSimpleName() + " " - + op.binaryFn.getClass().getSimpleName() + " m1,m2,m3 " + m1C + " " + m2C + " " + m3C); - MatrixBlock left = getUncompressed(m1); - MatrixBlock right1 = getUncompressed(m2); - MatrixBlock right2 = getUncompressed(m3); - ret = left.aggregateTernaryOperations(left, right1, right2, ret, op, inCP); - if(ret.getNumRows() == 0 || ret.getNumColumns() == 0) - throw new DMLCompressionException("Invalid output"); - return ret; - } - @Override public MatrixBlock uaggouterchainOperations(MatrixBlock mbLeft, MatrixBlock mbRight, MatrixBlock mbOut, BinaryOperator bOp, AggregateUnaryOperator uaggOp) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java index 5074a695df3..74b6cc8e0a8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java @@ -265,8 +265,10 @@ public static CompressedMatrixBlock createConstant(int numRows, int numCols, dou } private Pair compressMatrix() { - if(mb.getNonZeros() < 0) - throw new DMLCompressionException("Invalid to compress matrices with unknown nonZeros"); + if(mb.getNonZeros() < 0) { + LOG.warn("Recomputing non-zeros since it is unknown in compression"); + mb.recomputeNonZeros(); + } else if(mb instanceof CompressedMatrixBlock && ((CompressedMatrixBlock) mb).isOverlapping()) { LOG.warn("Unsupported recompression of overlapping compression"); return new ImmutablePair<>(mb, null); diff --git a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodePriorityQue.java b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodePriorityQue.java index 0873999f554..dda2efb48f5 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodePriorityQue.java +++ b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodePriorityQue.java @@ -48,8 +48,7 @@ public class CoCodePriorityQue extends AColumnCoCoder { private static final int COL_COMBINE_THRESHOLD = 1024; - protected CoCodePriorityQue(AComEst sizeEstimator, ACostEstimate costEstimator, - CompressionSettings cs) { + protected CoCodePriorityQue(AComEst sizeEstimator, ACostEstimate costEstimator, CompressionSettings cs) { super(sizeEstimator, costEstimator, cs); } @@ -59,8 +58,8 @@ protected CompressedSizeInfo coCodeColumns(CompressedSizeInfo colInfos, int k) { return colInfos; } - protected static List join(List groups, - AComEst sEst, ACostEstimate cEst, int minNumGroups, int k) { + protected static List join(List groups, AComEst sEst, + ACostEstimate cEst, int minNumGroups, int k) { if(groups.size() > COL_COMBINE_THRESHOLD && k > 1) return combineMultiThreaded(groups, sEst, cEst, minNumGroups, k); @@ -111,16 +110,19 @@ private static List combineBlock(List combineBlock(Queue que, - AComEst sEst, ACostEstimate cEst, int minNumGroups) { + private static List combineBlock(Queue que, AComEst sEst, + ACostEstimate cEst, int minNumGroups) { List ret = new ArrayList<>(); CompressedSizeInfoColGroup l = null; l = que.poll(); int groupNr = ret.size() + que.size(); - while(que.peek() != null && groupNr >= minNumGroups) { + int lastCombine = 0; // if we have not combined in the last 5 tries abort cocoding. + + while(que.peek() != null && groupNr >= minNumGroups && lastCombine < 5) { CompressedSizeInfoColGroup r = que.peek(); CompressedSizeInfoColGroup g = sEst.combine(l, r); + if(g != null) { double costOfJoin = cEst.getCost(g); double costIndividual = cEst.getCost(l) + cEst.getCost(r); @@ -128,20 +130,33 @@ private static List combineBlock(Queue 128) + if(numColumns > 128){ + lastCombine++; ret.add(g); - else + } + else{ + lastCombine = 0; que.add(g); + } } - else + else{ + lastCombine++; ret.add(l); + } } - else + else{ + lastCombine++; ret.add(l); + } l = que.poll(); groupNr = ret.size() + que.size(); } + while(que.peek() != null){ + // empty que + ret.add(l); + l = que.poll(); + } if(l != null) ret.add(l); @@ -153,11 +168,15 @@ private static List combineBlock(Queue getQue(int size, ACostEstimate cEst) { - Comparator comp = Comparator.comparing(x -> cEst.getCost(x)); + Comparator comp = Comparator.comparing(x -> getCost(x, cEst)); Queue que = new PriorityQueue<>(size, comp); return que; } + private static double getCost(CompressedSizeInfoColGroup x, ACostEstimate cEst) { + return cEst.getCost(x) + x.getColumns().avgOfIndex() / 100000; + } + protected static class PQTask implements Callable> { private final List _groups; @@ -167,8 +186,8 @@ protected static class PQTask implements Callable groups, int start, int end, AComEst sEst, - ACostEstimate cEst, int minNumGroups) { + protected PQTask(List groups, int start, int end, AComEst sEst, ACostEstimate cEst, + int minNumGroups) { _groups = groups; _start = start; _end = end; diff --git a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCoderFactory.java b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCoderFactory.java index a9b4a2bb520..abd12d3f6a8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCoderFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCoderFactory.java @@ -22,16 +22,19 @@ import java.util.ArrayList; import java.util.List; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.compress.CompressionSettings; +import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.cost.ACostEstimate; import org.apache.sysds.runtime.compress.estim.AComEst; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; -import org.apache.sysds.runtime.compress.utils.IntArrayList; public interface CoCoderFactory { + public static final Log LOG = LogFactory.getLog(AColumnCoCoder.class.getName()); /** * The Valid coCoding techniques @@ -53,51 +56,69 @@ public enum PartitionerType { * @param cs The compression settings used in the compression. * @return The estimated (hopefully) best groups of ColGroups. */ - public static CompressedSizeInfo findCoCodesByPartitioning(AComEst est, CompressedSizeInfo colInfos, - int k, ACostEstimate costEstimator, CompressionSettings cs) { + public static CompressedSizeInfo findCoCodesByPartitioning(AComEst est, CompressedSizeInfo colInfos, int k, + ACostEstimate costEstimator, CompressionSettings cs) { // Use column group partitioner to create partitions of columns AColumnCoCoder co = createColumnGroupPartitioner(cs.columnPartitioner, est, costEstimator, cs); // Find out if any of the groups are empty. - boolean containsEmpty = false; - for(CompressedSizeInfoColGroup g : colInfos.compressionInfo) { - if(g.isEmpty()) { - containsEmpty = true; - break; - } - } + final boolean containsEmptyOrConst = containsEmptyOrConst(colInfos); - // if there are no empty columns then try cocode algorithms for all columns - if(!containsEmpty) + // if there are no empty or const columns then try cocode algorithms for all columns + if(!containsEmptyOrConst) return co.coCodeColumns(colInfos, k); + else { + // filtered empty groups + final List emptyCols = new ArrayList<>(); + // filtered const groups + final List constCols = new ArrayList<>(); + // filtered groups -- in the end starting with all groups + final List groups = new ArrayList<>(); + + final int nRow = colInfos.compressionInfo.get(0).getNumRows(); + + // filter groups + for(int i = 0; i < colInfos.compressionInfo.size(); i++) { + CompressedSizeInfoColGroup g = colInfos.compressionInfo.get(i); + if(g.isEmpty()) + emptyCols.add(g.getColumns()); + else if(g.isConst()) + constCols.add(g.getColumns()); + else + groups.add(g); + } - // extract all empty columns - IntArrayList emptyCols = new IntArrayList(); - List notEmpty = new ArrayList<>(); + // overwrite groups. + colInfos.compressionInfo = groups; + + // cocode remaining groups + if(!groups.isEmpty()) { + colInfos = co.coCodeColumns(colInfos, k); + } - for(CompressedSizeInfoColGroup g : colInfos.compressionInfo) { - if(g.isEmpty()) - emptyCols.appendValue(g.getColumns().get(0)); - else - notEmpty.add(g); - } + // add empty + if(emptyCols.size() > 0) { + final IColIndex idx = ColIndexFactory.combineIndexes(emptyCols); + colInfos.compressionInfo.add(new CompressedSizeInfoColGroup(idx, nRow, CompressionType.EMPTY)); + } - final int nRow = colInfos.compressionInfo.get(0).getNumRows(); + // add const + if(constCols.size() > 0) { + final IColIndex idx = ColIndexFactory.combineIndexes(constCols); + colInfos.compressionInfo.add(new CompressedSizeInfoColGroup(idx, nRow, CompressionType.CONST)); + } - final IColIndex idx = ColIndexFactory.create(emptyCols); - if(notEmpty.isEmpty()) { // if all empty (unlikely but could happen) - CompressedSizeInfoColGroup empty = new CompressedSizeInfoColGroup(idx, nRow); - return new CompressedSizeInfo(empty); - } + return colInfos; - // cocode all not empty columns - colInfos.compressionInfo = notEmpty; - colInfos = co.coCodeColumns(colInfos, k); + } + } - // add empty columns back as single columns - colInfos.compressionInfo.add(new CompressedSizeInfoColGroup(idx, nRow)); - return colInfos; + private static boolean containsEmptyOrConst(CompressedSizeInfo colInfos) { + for(CompressedSizeInfoColGroup g : colInfos.compressionInfo) + if(g.isEmpty() || g.isConst()) + return true; + return false; } private static AColumnCoCoder createColumnGroupPartitioner(PartitionerType type, AComEst est, diff --git a/src/main/java/org/apache/sysds/runtime/compress/cocode/ColIndexes.java b/src/main/java/org/apache/sysds/runtime/compress/cocode/ColIndexes.java index 3d7e91f9213..dcdcbe464c0 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/cocode/ColIndexes.java +++ b/src/main/java/org/apache/sysds/runtime/compress/cocode/ColIndexes.java @@ -45,10 +45,6 @@ public boolean contains(ColIndexes a, ColIndexes b) { if(a == null || b == null) return false; - int id = _indexes.findIndex(a._indexes.get(0)); - if(id >= 0) - return true; - id = _indexes.findIndex(b._indexes.get(0)); - return id >= 0; + return _indexes.contains(a._indexes.get(0)) || _indexes.contains(b._indexes.get(0)); } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDC.java index 96c3dda02dd..3c63cca7d2f 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDC.java @@ -22,6 +22,8 @@ import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; +import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; +import org.apache.sysds.runtime.compress.colgroup.scheme.SDCScheme; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; import org.apache.sysds.runtime.compress.estim.EstimationFactors; @@ -62,4 +64,9 @@ public final CompressedSizeInfoColGroup getCompressionInfo(int nRow) { EstimationFactors ef = new EstimationFactors(getNumValues(), _numRows, getNumberOffsets(), _dict.getSparsity()); return new CompressedSizeInfoColGroup(_colIndexes, ef, nRow, getCompType(),getEncoding()); } + + @Override + public ICLAScheme getCompressionScheme() { + return SDCScheme.create(this); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java index 041458621d8..23ce0be2556 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java @@ -24,6 +24,8 @@ import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; +import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; +import org.apache.sysds.runtime.compress.colgroup.scheme.SDCScheme; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; import org.apache.sysds.runtime.compress.estim.EstimationFactors; import org.apache.sysds.runtime.data.DenseBlock; @@ -230,4 +232,9 @@ public final CompressedSizeInfoColGroup getCompressionInfo(int nRow) { EstimationFactors ef = new EstimationFactors(getNumValues(), _numRows, getNumberOffsets(), _dict.getSparsity()); return new CompressedSizeInfoColGroup(_colIndexes, ef, nRow, getCompType(),getEncoding()); } + + @Override + public ICLAScheme getCompressionScheme() { + return SDCScheme.create(this); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java index 0b44440e71a..70029e21ec0 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java @@ -459,12 +459,12 @@ public AColGroup appendNInternal(AColGroup[] g) { @Override public ICLAScheme getCompressionScheme() { - return null; + throw new NotImplementedException(); } @Override public AColGroup recompress() { - return this; + throw new NotImplementedException(); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java index c908b267e3e..bb4a63afbbe 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java @@ -352,7 +352,8 @@ public AColGroup recompress() { @Override public CompressedSizeInfoColGroup getCompressionInfo(int nRow) { EstimationFactors ef = new EstimationFactors(getNumValues(), 1, 0, 0.0); - return new CompressedSizeInfoColGroup(_colIndexes, ef, estimateInMemorySize(), CompressionType.CONST, getEncoding()); + return new CompressedSizeInfoColGroup(_colIndexes, ef, estimateInMemorySize(), CompressionType.EMPTY, + getEncoding()); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java index bac40cb6875..cdec096da83 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java @@ -671,7 +671,7 @@ public AColGroup appendNInternal(AColGroup[] g) { @Override public ICLAScheme getCompressionScheme() { - return null; + throw new NotImplementedException(); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java index 692e2496a80..1840d248764 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java @@ -35,6 +35,7 @@ import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; import org.apache.sysds.runtime.compress.colgroup.offset.AOffsetIterator; import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; +import org.apache.sysds.runtime.compress.colgroup.scheme.RLEScheme; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; import org.apache.sysds.runtime.data.DenseBlock; @@ -982,7 +983,7 @@ public AColGroup appendNInternal(AColGroup[] g) { @Override public ICLAScheme getCompressionScheme() { - return null; + return RLEScheme.create(this); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java index 03012c7c682..e4977c05595 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java @@ -38,7 +38,6 @@ import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.OffsetSliceInfo; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; -import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; @@ -623,11 +622,6 @@ public AColGroup appendNInternal(AColGroup[] g) { return create(_colIndexes, sumRows, _dict, _defaultTuple, no, nd, null); } - @Override - public ICLAScheme getCompressionScheme() { - return null; - } - @Override public AColGroup recompress() { return this; diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java index 9783344ab42..294a47c4372 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java @@ -62,7 +62,7 @@ * with no modifications. * */ -public class ColGroupSDCFOR extends ASDC implements IMapToDataGroup , IFrameOfReferenceGroup{ +public class ColGroupSDCFOR extends ASDC implements IMapToDataGroup, IFrameOfReferenceGroup { private static final long serialVersionUID = 3883228464052204203L; @@ -486,11 +486,6 @@ public AColGroup appendNInternal(AColGroup[] g) { return create(_colIndexes, sumRows, _dict, no, nd, null, _reference); } - @Override - public ICLAScheme getCompressionScheme() { - return null; - } - @Override public AColGroup recompress() { return this; @@ -521,6 +516,11 @@ public int getNumberOffsets() { return _data.size(); } + @Override + public ICLAScheme getCompressionScheme() { + throw new NotImplementedException(); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java index 7f43df5f8f5..1182d80ba57 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java @@ -36,7 +36,6 @@ import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.OffsetSliceInfo; import org.apache.sysds.runtime.compress.colgroup.offset.AOffsetIterator; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; -import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; @@ -580,11 +579,6 @@ public AColGroup appendNInternal(AColGroup[] g) { return create(_colIndexes, sumRows, _dict, _defaultTuple, no, null); } - @Override - public ICLAScheme getCompressionScheme() { - return null; - } - @Override public AColGroup recompress() { return this; diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java index 9392b1f23fd..69a93d8ef15 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java @@ -38,7 +38,6 @@ import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.OffsetSliceInfo; import org.apache.sysds.runtime.compress.colgroup.offset.AOffsetIterator; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; -import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; @@ -862,11 +861,6 @@ public AColGroup appendNInternal(AColGroup[] g) { return create(_colIndexes, sumRows, _dict, no, null); } - @Override - public ICLAScheme getCompressionScheme() { - return null; - } - @Override public AColGroup recompress() { return this; diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java index 68eb2144959..4a6f6b50b83 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java @@ -38,7 +38,6 @@ import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.OffsetSliceInfo; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; -import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; @@ -793,11 +792,6 @@ public AColGroup appendNInternal(AColGroup[] g) { return create(_colIndexes, sumRows, _dict, no, nd, null); } - @Override - public ICLAScheme getCompressionScheme() { - return null; - } - @Override public AColGroup recompress() { return this; diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java index 3bbbc5fcc3d..d039e6ff3d1 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java @@ -25,6 +25,7 @@ import java.util.Arrays; import java.util.List; +import org.apache.sysds.runtime.compress.colgroup.scheme.SchemeFactory; import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; @@ -810,7 +811,7 @@ public AColGroup appendNInternal(AColGroup[] g) { @Override public ICLAScheme getCompressionScheme() { - return null; + return SchemeFactory.create(_colIndexes, CompressionType.UNCOMPRESSED); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/AColIndex.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/AColIndex.java index cf22ba0d7b6..df4685a65d6 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/AColIndex.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/AColIndex.java @@ -69,10 +69,11 @@ private static int hashCode(IIterate it) { @Override public boolean containsAny(IColIndex idx) { - IIterate it = idx.iterator(); + final IIterate it = idx.iterator(); while(it.hasNext()) if(contains(it.next())) return true; + return false; } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ArrayIndex.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ArrayIndex.java index 711236cb29a..c03aba628bc 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ArrayIndex.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ArrayIndex.java @@ -147,6 +147,13 @@ else if(other instanceof RangeIndex) public IColIndex combine(IColIndex other) { final int sr = other.size(); final int sl = size(); + final int maxCombined = Math.max(this.get(this.size() - 1), other.get(other.size() - 1)); + final int minCombined = Math.min(this.get(0), other.get(0)); + if(sr + sl == maxCombined - minCombined + 1) { + return new RangeIndex(minCombined, maxCombined + 1); + } + + // LOG.error("Combining Worst " + this + " " + other); final int[] ret = new int[sr + sl]; int pl = 0; int pr = 0; @@ -204,10 +211,20 @@ public IColIndex sort() { @Override public boolean contains(int i) { + if(i < cols[0] || i > cols[cols.length - 1]) + return false; int id = Arrays.binarySearch(cols, 0, cols.length, i); return id >= 0; } + @Override + public double avgOfIndex() { + double s = 0.0; + for(int i = 0; i < cols.length; i++) + s += cols[i]; + return s / cols.length; + } + protected class ArrayIterator implements IIterate { int id = 0; diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ColIndexFactory.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ColIndexFactory.java index 4ecde22ff48..fd929b8a1aa 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ColIndexFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ColIndexFactory.java @@ -46,6 +46,8 @@ public static IColIndex read(DataInput in) throws IOException { return ArrayIndex.read(in); case RANGE: return RangeIndex.read(in); + case TWORANGE: + return TwoRangesIndex.read(in); default: throw new DMLCompressionException("Failed reading column index of type: " + t); } @@ -58,7 +60,7 @@ public static IColIndex createI(int... indexes) { public static IColIndex create(int[] indexes) { if(indexes.length <= 0) throw new DMLRuntimeException("Invalid length to create index from : " + indexes.length); - if(indexes.length == 1) + else if(indexes.length == 1) return new SingleIndex(indexes[0]); else if(indexes.length == 2) return new TwoIndex(indexes[0], indexes[1]); @@ -82,6 +84,13 @@ else if(RangeIndex.isValidRange(indexes)) return new ArrayIndex(indexes.extractValues(true)); } + /** + * Create an Index range of the given values + * + * @param l Lower bound (inclusive) + * @param u Upper bound (not inclusive) + * @return An Index + */ public static IColIndex create(int l, int u) { if(u - l <= 0) throw new DMLRuntimeException("Invalid range: " + l + " " + u); @@ -133,6 +142,24 @@ public static IColIndex combine(List gs) { return create(resCols); } + public static IColIndex combineIndexes(List idx) { + int numCols = 0; + for(IColIndex g : idx) + numCols += g.size(); + + int[] resCols = new int[numCols]; + + int index = 0; + for(IColIndex g : idx) { + final IIterate it = g.iterator(); + while(it.hasNext()) + resCols[index++] = it.next(); + } + + Arrays.sort(resCols); + return create(resCols); + } + public static IColIndex combine(AColGroup a, AColGroup b) { return combine(a.getColIndices(), b.getColIndices()); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java index 5163998ef89..60c2cec4b23 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java @@ -28,7 +28,7 @@ public interface IColIndex { public static enum ColIndexType { - SINGLE, TWO, ARRAY, RANGE, UNKNOWN; + SINGLE, TWO, ARRAY, RANGE, TWORANGE, UNKNOWN; } /** @@ -94,6 +94,25 @@ public static enum ColIndexType { */ public int findIndex(int i); + /** + * Slice the range given. + * + * The slice result is an object containing the indexes in the original array to slice out and a new index for the + * sliced columns offset by l. + * + * Example: + * + * ArrayIndex(1,3,5).slice(2,6) + * + * returns + * + * SliceResult(1,3,ArrayIndex(1,3)) + * + * + * @param l inclusive lower bound + * @param u exclusive upper bound + * @return A slice result + */ public SliceResult slice(int l, int u); @Override @@ -186,6 +205,13 @@ public static enum ColIndexType { */ public boolean containsAny(IColIndex idx); + /** + * Get the average of this index. We use this to sort the priority que when combining equivalent costly groups + * + * @return The average of the indexes. + */ + public double avgOfIndex(); + /** A Class for slice results containing indexes for the slicing of dictionaries, and the resulting column index */ public static class SliceResult { /** Start index to slice inside the dictionary */ @@ -195,6 +221,12 @@ public static class SliceResult { /** The already modified column index to return on slices */ public final IColIndex ret; + /** + * The slice result + * @param idStart The starting index + * @param idEnd The ending index (not inclusive) + * @param ret The resulting IColIndex + */ protected SliceResult(int idStart, int idEnd, IColIndex ret) { this.idStart = idStart; this.idEnd = idEnd; diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/RangeIndex.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/RangeIndex.java index 7705c586d14..bbe5aeb8a5c 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/RangeIndex.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/RangeIndex.java @@ -22,6 +22,7 @@ import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; +import java.util.Arrays; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.utils.IntArrayList; @@ -52,7 +53,7 @@ public RangeIndex(int nCol) { * Construct an range index with lower and upper values given. * * @param l lower index - * @param u Upper index + * @param u Upper index not inclusive */ public RangeIndex(int l, int u) { this.l = l; @@ -73,7 +74,7 @@ public int get(int i) { } @Override - public IColIndex shift(int i) { + public RangeIndex shift(int i) { return new RangeIndex(l + i, u + i); } @@ -121,7 +122,6 @@ else if(i < u) @Override public SliceResult slice(int l, int u) { - if(u <= this.l) return new SliceResult(0, 0, null); else if(l >= this.u) @@ -129,9 +129,11 @@ else if(l >= this.u) else if(l <= this.l && u >= this.u) return new SliceResult(0, size(), new RangeIndex(this.l - l, this.u - l)); else { - int offL = Math.max(l, this.l) - this.l; - int offR = Math.min(u, this.u) - this.l; - return new SliceResult(offL, offR, new RangeIndex(Math.max(l, this.l) - l, Math.min(u, this.u) - l)); + int maxL = Math.max(l, this.l); + int minU = Math.min(u, this.u); + int offL = maxL - this.l; + int offR = minU - this.l; + return new SliceResult(offL, offR, new RangeIndex(maxL - l, minU - l )); } } @@ -154,6 +156,16 @@ public IColIndex combine(IColIndex other) { else if(v == u) return new RangeIndex(l, u + 1); } + if(other instanceof RangeIndex) { + if(other.get(0) == u) + return new RangeIndex(l, other.get(other.size() - 1) + 1); + else if(other.get(other.size() - 1) == l - 1) + return new RangeIndex(other.get(0), u); + else if(other.get(0) < this.get(0)) + return new TwoRangesIndex((RangeIndex) other, this); + else + return new TwoRangesIndex(this, (RangeIndex) other); + } final int sr = other.size(); final int sl = size(); @@ -186,18 +198,6 @@ public boolean isContiguous() { return true; } - @Override - public String toString() { - StringBuilder sb = new StringBuilder(); - sb.append(this.getClass().getSimpleName()); - sb.append("["); - sb.append(l); - sb.append(" -> "); - sb.append(u); - sb.append("]"); - return sb.toString(); - } - protected static boolean isValidRange(int[] indexes) { return isValidRange(indexes, indexes.length); } @@ -210,10 +210,14 @@ private static boolean isValidRange(final int[] indexes, final int length) { int len = length; int first = indexes[0]; int last = indexes[length - 1]; - if(last - first + 1 == len) { + + final boolean isPossibleFistAndLast = last - first + 1 >= len; + if(!isPossibleFistAndLast) + throw new DMLCompressionException("Invalid Index " + Arrays.toString(indexes)); + else if(last - first + 1 == len) { for(int i = 1; i < length; i++) - if(indexes[i - 1] > indexes[i]) - return false; + if(indexes[i - 1] >= indexes[i]) + throw new DMLCompressionException("Invalid Index"); return true; } else @@ -240,6 +244,31 @@ public boolean contains(int i) { return l <= i && i < u; } + @Override + public double avgOfIndex() { + double diff = u - 1 - l; + // double s = l * diff + diff * diff * 0.5; + // return s / diff; + return l + diff * 0.5; + } + + @Override + public int hashCode() { + return 31 * l + u; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(this.getClass().getSimpleName()); + sb.append("["); + sb.append(l); + sb.append(" -> "); + sb.append(u); + sb.append("]"); + return sb.toString(); + } + protected class RangeIterator implements IIterate { int cl = l; diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/SingleIndex.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/SingleIndex.java index 5d10b2a39d4..97325460537 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/SingleIndex.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/SingleIndex.java @@ -130,6 +130,11 @@ public boolean contains(int i) { return i == idx; } + @Override + public double avgOfIndex() { + return idx; + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/TwoIndex.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/TwoIndex.java index 9e3e8480f80..16305d2d4e9 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/TwoIndex.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/TwoIndex.java @@ -113,16 +113,16 @@ public IColIndex combine(IColIndex other) { if(other instanceof SingleIndex) { int otherV = other.get(0); if(otherV < id1) - return new ArrayIndex(new int[] {otherV, id1, id2}); + return ColIndexFactory.create(new int[] {otherV, id1, id2}); else if(otherV < id2) - return new ArrayIndex(new int[] {id1, otherV, id2}); + return ColIndexFactory.create(new int[] {id1, otherV, id2}); else - return new ArrayIndex(new int[] {id1, id2, otherV}); + return ColIndexFactory.create(new int[] {id1, id2, otherV}); } else if(other instanceof TwoIndex) { int[] vals = new int[] {other.get(0), other.get(1), id1, id2}; Arrays.sort(vals); - return new ArrayIndex(vals); + return ColIndexFactory.create(vals); } else return other.combine(this); @@ -159,6 +159,11 @@ public boolean contains(int i) { return i == id1 || i == id2; } + @Override + public double avgOfIndex() { + return (id1 + id2) * 0.5; + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/TwoRangesIndex.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/TwoRangesIndex.java new file mode 100644 index 00000000000..51634b92692 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/TwoRangesIndex.java @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.compress.colgroup.indexes; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +import org.apache.sysds.runtime.compress.DMLCompressionException; + +public class TwoRangesIndex extends AColIndex { + + /** The lower index range */ + private final RangeIndex idx1; + /** The upper index range */ + private final RangeIndex idx2; + + public TwoRangesIndex(RangeIndex lower, RangeIndex higher) { + this.idx1 = lower; + this.idx2 = higher; + } + + @Override + public int size() { + return idx1.size() + idx2.size(); + } + + @Override + public int get(int i) { + if(i < idx1.size()) + return idx1.get(i); + else + return idx2.get(i - idx1.size()); + } + + @Override + public IColIndex shift(int i) { + return new TwoRangesIndex(idx1.shift(i), idx2.shift(i)); + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeByte(ColIndexType.TWORANGE.ordinal()); + out.writeInt(idx1.get(0)); + out.writeInt(idx1.size()); + out.writeInt(idx2.get(0)); + out.writeInt(idx2.size()); + } + + public static TwoRangesIndex read(DataInput in) throws IOException { + int l1 = in.readInt(); + int u1 = in.readInt() + l1; + int l2 = in.readInt(); + int u2 = in.readInt() + l2; + return new TwoRangesIndex(new RangeIndex(l1, u1), new RangeIndex(l2, u2)); + } + + @Override + public long getExactSizeOnDisk() { + return 1 + 4 + 4 + 4 + 4; + } + + @Override + public long estimateInMemorySize() { + return estimateInMemorySizeStatic(); + } + + public static long estimateInMemorySizeStatic() { + return 16 + 8 + 8 + RangeIndex.estimateInMemorySizeStatic() * 2; + } + + @Override + public IIterate iterator() { + return new TwoRangesIterator(); + } + + @Override + public int findIndex(int i) { + int aix = idx1.findIndex(i); + if(aix < -1 * idx1.size()) { + int bix = idx2.findIndex(i); + if(bix < 0) + return aix + bix + 1; + else + return idx1.size() + bix; + } + else + return aix; + + } + + @Override + public SliceResult slice(int l, int u) { + if(u <= idx1.get(0)) + return new SliceResult(0, 0, null); + else if(l >= idx2.get(idx2.size() - 1)) + return new SliceResult(0, 0, null); + else if(l <= idx1.get(0) && u >= idx2.get(idx2.size() - 1)) { + RangeIndex ids1 = idx1.shift(-l); + RangeIndex ids2 = idx2.shift(-l); + return new SliceResult(0, size(), new TwoRangesIndex(ids1, ids2)); + } + + SliceResult sa = idx1.slice(l, u); + SliceResult sb = idx2.slice(l, u); + if(sa.ret == null) { + return new SliceResult(idx1.size() + sb.idStart, idx1.size() + sb.idEnd, sb.ret); + } + else if(sb.ret == null) + // throw new NotImplementedException(); + return sa; + else { + IColIndex c = sa.ret.combine(sb.ret); + return new SliceResult(sa.idStart, sa.idStart + sb.idEnd, c); + } + } + + @Override + public boolean equals(IColIndex other) { + if(other instanceof TwoRangesIndex) { + TwoRangesIndex otri = (TwoRangesIndex) other; + return idx1.equals(otri.idx1) && idx2.equals(otri.idx2); + } + else if(other instanceof RangeIndex) + return false; + else + return other.equals(this); + } + + @Override + public IColIndex combine(IColIndex other) { + final int sr = other.size(); + final int sl = size(); + final int[] ret = new int[sr + sl]; + + int pl = 0; + int pr = 0; + int i = 0; + while(pl < sl && pr < sr) { + final int vl = get(pl); + final int vr = other.get(pr); + if(vl < vr) { + ret[i++] = vl; + pl++; + } + else { + ret[i++] = vr; + pr++; + } + } + while(pl < sl) + ret[i++] = get(pl++); + while(pr < sr) + ret[i++] = other.get(pr++); + return ColIndexFactory.create(ret); + } + + @Override + public boolean isContiguous() { + return false; + } + + @Override + public int[] getReorderingIndex() { + throw new DMLCompressionException("not valid to get reordering Index for range"); + } + + @Override + public boolean isSorted() { + return true; + } + + @Override + public IColIndex sort() { + throw new DMLCompressionException("range is always sorted"); + } + + @Override + public boolean contains(int i) { + return idx1.contains(i) || idx2.contains(i); + } + + @Override + public double avgOfIndex() { + return (idx1.avgOfIndex() * idx1.size() + idx2.avgOfIndex() * idx2.size()) / size(); + } + + @Override + public int hashCode() { + // 811 is a prime. + return idx1.hashCode() * 811 + idx2.hashCode(); + } + + @Override + public boolean containsAny(IColIndex idx) { + return idx1.containsAny(idx) || idx2.containsAny(idx); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(this.getClass().getSimpleName()); + sb.append("["); + sb.append(idx1.get(0)); + sb.append(" -> "); + sb.append(idx1.get(idx1.size())); + sb.append(" And "); + sb.append(idx2.get(0)); + sb.append(" -> "); + sb.append(idx2.get(idx2.size())); + sb.append("]"); + return sb.toString(); + } + + protected class TwoRangesIterator implements IIterate { + IIterate a = idx1.iterator(); + IIterate b = idx2.iterator(); + boolean aDone = false; + + @Override + public int next() { + if(!aDone) { + int v = a.next(); + aDone = !a.hasNext(); + return v; + } + else + return b.next(); + } + + @Override + public boolean hasNext() { + return a.hasNext() || b.hasNext(); + } + + @Override + public int v() { + if(!aDone) + return a.v(); + else + return b.v(); + } + + @Override + public int i() { + if(!aDone) + return a.i(); + else + return a.i() + b.i(); + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java index b2d8623eaf7..8352141be39 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java @@ -110,16 +110,14 @@ public void write(DataOutput out) throws IOException { protected void writeBytes(DataOutput out) throws IOException { out.writeInt(getUnique()); out.writeInt(_data.length); - for(int i = 0; i < _data.length; i++) - out.writeByte(_data[i]); + out.write(_data); } protected static MapToByte readFields(DataInput in) throws IOException { final int unique = in.readInt(); final int length = in.readInt(); final byte[] data = new byte[length]; - for(int i = 0; i < length; i++) - data[i] = in.readByte(); + in.readFully(data); return new MapToByte(unique, data); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/CompressionScheme.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/CompressionScheme.java new file mode 100644 index 00000000000..51c3ddffe0a --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/CompressionScheme.java @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.compress.colgroup.scheme; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; + +import org.apache.commons.lang.NotImplementedException; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.compress.DMLCompressionException; +import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.CommonThreadPool; + +/** + * A Class that contains a full compression scheme that can be applied to MatrixBlocks. + */ +public class CompressionScheme { + + protected static final Log LOG = LogFactory.getLog(CompressionScheme.class.getName()); + + private final ICLAScheme[] encodings; + + public CompressionScheme(ICLAScheme[] encodings) { + this.encodings = encodings; + } + + /** + * Get the encoding in a specific index. + * + * @param i the index + * @return The encoding in that index + */ + public ICLAScheme get(int i) { + return encodings[i]; + } + + /** + * Encode the given matrix block, it is assumed that the given MatrixBlock already fit the current scheme. + * + * @param mb A MatrixBlock given that should fit the scheme + * @return A Compressed instance of the given matrixBlock; + */ + public CompressedMatrixBlock encode(MatrixBlock mb) { + if(mb instanceof CompressedMatrixBlock) + throw new NotImplementedException("Not implemented schema encode/apply on an already compressed MatrixBlock"); + + List ret = new ArrayList<>(encodings.length); + + for(int i = 0; i < encodings.length; i++) + ret.add(encodings[i].encode(mb)); + + return new CompressedMatrixBlock(mb.getNumRows(), mb.getNumColumns(), mb.getNonZeros(), false, ret); + } + + /** + * Encode the given matrix block, it is assumed that the given MatrixBlock already fit the current scheme. + * + * @param mb A MatrixBlock given that should fit the scheme + * @param k The parallelization degree + * @return A Compressed instance of the given matrixBlock; + */ + public CompressedMatrixBlock encode(MatrixBlock mb, int k) { + if(k == 1) + return encode(mb); + final ExecutorService pool = CommonThreadPool.get(k); + try { + + List tasks = new ArrayList<>(); + for(int i = 0; i < encodings.length; i++) + tasks.add(new EncodeTask(encodings[i], mb)); + + List ret = new ArrayList<>(encodings.length); + for(Future t : pool.invokeAll(tasks)) + ret.add(t.get()); + + return new CompressedMatrixBlock(mb.getNumRows(), mb.getNumColumns(), mb.getNonZeros(), false, ret); + + } + catch(Exception e) { + throw new DMLCompressionException("Failed encoding", e); + } + finally { + pool.shutdown(); + } + } + + /** + * Update the encodings contained to also enable compression of the given mb. + * + * @param mb The matrixBlock to enable compression on. + * @return The updated scheme. (It is updated in place) + */ + public CompressionScheme update(MatrixBlock mb) { + if(mb instanceof CompressedMatrixBlock) + throw new NotImplementedException("Not implemented schema encode/apply on an already compressed MatrixBlock"); + + for(int i = 0; i < encodings.length; i++) + encodings[i] = encodings[i].update(mb); + + return this; + + } + + /** + * Update the encodings contained to also enable compression of the given mb. + * + * @param mb The matrixBlock to enable compression on. + * @param k The parallelization degree + * @return The updated scheme. (It is updated in place) + */ + public CompressionScheme update(MatrixBlock mb, int k) { + if(k == 1) + return update(mb); + final ExecutorService pool = CommonThreadPool.get(k); + try { + + List tasks = new ArrayList<>(); + for(int i = 0; i < encodings.length; i++) + tasks.add(new UpdateTask(encodings[i], mb)); + + List> ret = pool.invokeAll(tasks); + + for(int i = 0; i < encodings.length; i++) + encodings[i] = ret.get(i).get(); + + return this; + } + catch(Exception e) { + throw new DMLCompressionException("Failed encoding", e); + } + finally { + pool.shutdown(); + } + } + + /** Extract a compression scheme for the given matrix block */ + + /** + * Extract a compression scheme for the given matrix block + * + * @param cmb The given compressed matrix block + * @return A Compression scheme that can be applied to new encodings. + */ + public static CompressionScheme getScheme(CompressedMatrixBlock cmb) { + if(cmb.isOverlapping()) + throw new DMLCompressionException("Invalid to extract CompressionScheme from an overlapping compression"); + + List gs = cmb.getColGroups(); + + ICLAScheme[] ret = new ICLAScheme[gs.size()]; + + for(int i = 0; i < gs.size(); i++) + ret[i] = gs.get(i).getCompressionScheme(); + + return new CompressionScheme(ret); + } + + public CompressedMatrixBlock updateAndEncode(MatrixBlock mb, int k) { + if(k == 1) + return updateAndEncode(mb); + final ExecutorService pool = CommonThreadPool.get(k); + try { + + List tasks = new ArrayList<>(); + for(int i = 0; i < encodings.length; i++) + tasks.add(new UpdateAndEncodeTask(i, encodings[i], mb)); + + List ret = new ArrayList<>(encodings.length); + for(Future t : pool.invokeAll(tasks)) + ret.add(t.get()); + + return new CompressedMatrixBlock(mb.getNumRows(), mb.getNumColumns(), mb.getNonZeros(), false, ret); + + } + catch(Exception e) { + throw new DMLCompressionException("Failed encoding", e); + } + finally { + pool.shutdown(); + } + } + + public CompressedMatrixBlock updateAndEncode(MatrixBlock mb) { + if(mb instanceof CompressedMatrixBlock) + throw new NotImplementedException("Not implemented schema encode/apply on an already compressed MatrixBlock"); + + List ret = new ArrayList<>(encodings.length); + + for(int i = 0; i < encodings.length; i++) { + encodings[i] = encodings[i].update(mb); + ret.add(encodings[i].encode(mb)); + } + + return new CompressedMatrixBlock(mb.getNumRows(), mb.getNumColumns(), mb.getNonZeros(), false, ret); + + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(this.getClass().getSimpleName()); + sb.append("\n"); + sb.append(Arrays.toString(encodings)); + return sb.toString(); + } + + protected class EncodeTask implements Callable { + final ICLAScheme enc; + final MatrixBlock mb; + + protected EncodeTask(ICLAScheme enc, MatrixBlock mb) { + this.enc = enc; + this.mb = mb; + } + + @Override + public AColGroup call() throws Exception { + return enc.encode(mb); + } + } + + protected class UpdateTask implements Callable { + final ICLAScheme enc; + final MatrixBlock mb; + + protected UpdateTask(ICLAScheme enc, MatrixBlock mb) { + this.enc = enc; + this.mb = mb; + } + + @Override + public ICLAScheme call() throws Exception { + return enc.update(mb); + } + } + + protected class UpdateAndEncodeTask implements Callable { + final int i; + final ICLAScheme enc; + final MatrixBlock mb; + + protected UpdateAndEncodeTask(int i, ICLAScheme enc, MatrixBlock mb) { + this.i = i; + this.enc = enc; + this.mb = mb; + } + + @Override + public AColGroup call() throws Exception { + encodings[i] = enc.update(mb); + return enc.encode(mb); + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/ConstScheme.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/ConstScheme.java index 3f96e78371d..e82874fe0d9 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/ConstScheme.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/ConstScheme.java @@ -19,8 +19,10 @@ package org.apache.sysds.runtime.compress.colgroup.scheme; -import org.apache.commons.lang3.NotImplementedException; +import java.util.Arrays; + import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType; import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -40,8 +42,8 @@ public static ICLAScheme create(ColGroupConst g) { return new ConstScheme(g.getColIndices(), g.getValues(), true); } - public static ICLAScheme create(IColIndex cols) { - return new ConstScheme(cols, new double[cols.size()], false); + public static ICLAScheme create(IColIndex cols, double[] vals) { + return new ConstScheme(cols, vals, false); } @Override @@ -51,13 +53,37 @@ protected IColIndex getColIndices() { @Override public ICLAScheme update(MatrixBlock data, IColIndex columns) { - throw new NotImplementedException(); + final int nRow = data.getNumRows(); + final int nColScheme = vals.length; + for(int r = 0; r < nRow; r++) + for(int c = 0; c < nColScheme; c++) { + final double v = data.quickGetValue(r, cols.get(c)); + if(Double.compare(v, vals[c]) != 0) + return updateToDDC(data, columns); + } + return this; + } + + private ICLAScheme updateToDDC(MatrixBlock data, IColIndex columns) { + return SchemeFactory.create(columns, CompressionType.DDC).update(data, columns); } @Override public AColGroup encode(MatrixBlock data, IColIndex columns) { validate(data, columns); + // we assume that it is always valid. return ColGroupConst.create(columns, vals); } + @Override + public final String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(this.getClass().getSimpleName()); + sb.append(" Cols: "); + sb.append(cols); + sb.append(" Def: "); + sb.append(Arrays.toString(vals)); + return sb.toString(); + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/DDCScheme.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/DDCScheme.java index 75a63ead600..2401946cae2 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/DDCScheme.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/DDCScheme.java @@ -25,7 +25,7 @@ public abstract class DDCScheme extends ACLAScheme { - // TODO make it into a soft refrence + // TODO make it into a soft reference protected ADictionary lastDict; protected DDCScheme(IColIndex cols) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/DDCSchemeSC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/DDCSchemeSC.java index 5fde8fe42ea..8a9dc692e03 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/DDCSchemeSC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/DDCSchemeSC.java @@ -22,6 +22,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.colgroup.AColGroup; import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC; +import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; @@ -76,7 +77,6 @@ private void updateSparse(MatrixBlock data, int col) { final SparseBlock sb = data.getSparseBlock(); for(int i = 0; i < nRow; i++) map.increment(sb.get(i, col)); - } private void updateDense(MatrixBlock data, int col) { @@ -101,8 +101,9 @@ private void updateGeneric(MatrixBlock data, int col) { @Override public AColGroup encode(MatrixBlock data, IColIndex columns) { - validate(data, columns); + if(data.isEmpty()) + return new ColGroupEmpty(columns); final int nRow = data.getNumRows(); final AMapToData d = MapToFactory.create(nRow, map.size()); encode(data, d, cols.get(0)); @@ -131,14 +132,13 @@ private void encodeSparse(MatrixBlock data, AMapToData d, int col) { } - private void encodeDense(MatrixBlock data, AMapToData d, int col) { + private void encodeDense(final MatrixBlock data, final AMapToData d, final int col) { final int nRow = data.getNumRows(); final double[] vals = data.getDenseBlockValues(); final int nCol = data.getNumColumns(); final int max = nRow * nCol; // guaranteed lower than intmax. for(int i = 0, off = col; off < max; i++, off += nCol) d.set(i, map.getId(vals[off])); - } private void encodeGeneric(MatrixBlock data, AMapToData d, int col) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/EmptyScheme.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/EmptyScheme.java index 4e27906b919..ed7d4b6e468 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/EmptyScheme.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/EmptyScheme.java @@ -19,7 +19,6 @@ package org.apache.sysds.runtime.compress.colgroup.scheme; -import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.colgroup.AColGroup; import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; @@ -27,17 +26,36 @@ public class EmptyScheme extends ACLAScheme { - protected EmptyScheme(ColGroupEmpty g) { - super(g.getColIndices()); + public EmptyScheme(IColIndex idx) { + super(idx); } public static EmptyScheme create(ColGroupEmpty g) { - return new EmptyScheme(g); + return new EmptyScheme(g.getColIndices()); } @Override public ICLAScheme update(MatrixBlock data, IColIndex columns) { - throw new NotImplementedException(); + if(data.isEmpty()) // all good + return this; + + final int nRow = data.getNumRows(); + final int nColScheme = cols.size(); + for(int r = 0; r < nRow; r++) + for(int c = 0; c < nColScheme; c++) + if(data.quickGetValue(r, cols.get(c)) != 0) + return updateToHigherScheme(data, columns); + + return this; + } + + private ICLAScheme updateToHigherScheme(MatrixBlock data, IColIndex columns) { + // try with const + double[] vals = new double[cols.size()]; + for(int c = 0; c < cols.size(); c++) + vals[c] = data.quickGetValue(0, c); + + return ConstScheme.create(columns, vals).update(data, columns); } @Override @@ -45,4 +63,13 @@ public AColGroup encode(MatrixBlock data, IColIndex columns) { validate(data, columns); return new ColGroupEmpty(columns); } + + @Override + public final String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(this.getClass().getSimpleName()); + sb.append(" Cols: "); + sb.append(cols); + return sb.toString(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/RLEScheme.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/RLEScheme.java new file mode 100644 index 00000000000..1c82cf9cf34 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/RLEScheme.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.compress.colgroup.scheme; + +import org.apache.commons.lang3.NotImplementedException; +import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.colgroup.ColGroupRLE; +import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; + +public class RLEScheme extends ACLAScheme { + + private static boolean messagePrinted = false; + // private final DoubleCountHashMap map; + // private final DblArrayCountHashMap map; + + public RLEScheme(IColIndex cols) { + super(cols); + if(!messagePrinted) + LOG.error("Not Implemented RLE Scheme yet"); + messagePrinted = true; + throw new NotImplementedException(); + } + + public static ICLAScheme create(ColGroupRLE g) { + return new RLEScheme(g.getColIndices()); + } + + @Override + public AColGroup encode(MatrixBlock data) { + throw new NotImplementedException(); + } + + @Override + public AColGroup encode(MatrixBlock data, IColIndex columns) { + throw new NotImplementedException(); + } + + @Override + public ICLAScheme update(MatrixBlock data) { + throw new NotImplementedException(); + } + + @Override + public ICLAScheme update(MatrixBlock data, IColIndex columns) { + throw new NotImplementedException(); + } + +} diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/SDCScheme.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/SDCScheme.java new file mode 100644 index 00000000000..b4231681c94 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/SDCScheme.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.compress.colgroup.scheme; + +import org.apache.commons.lang3.NotImplementedException; +import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.colgroup.ASDC; +import org.apache.sysds.runtime.compress.colgroup.ASDCZero; +import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCFOR; +import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; +import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; + +public abstract class SDCScheme extends ACLAScheme { + + // TODO make it into a soft reference + protected ADictionary lastDict; + + protected SDCScheme(IColIndex cols) { + super(cols); + } + + public static SDCScheme create(ASDC g) { + if(g instanceof ColGroupSDCFOR) + throw new NotImplementedException(); + if(g.getColIndices().size() == 1) + return new SDCSchemeSC(g); + else + return new SDCSchemeMC(g); + } + + public static SDCScheme create(ASDCZero g) { + if(g.getColIndices().size() == 1) + return new SDCSchemeSC(g); + else + return new SDCSchemeMC(g); + } + + @Override + public AColGroup encode(MatrixBlock data, IColIndex columns) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'encode'"); + } + + @Override + public ICLAScheme update(MatrixBlock data, IColIndex columns) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'update'"); + } + + protected abstract Object getDef(); + + protected abstract Object getMap(); + + @Override + public final String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(this.getClass().getSimpleName()); + sb.append("\nCols: "); + sb.append(cols); + sb.append("\nDef: "); + sb.append(getDef()); + sb.append("\nMap: "); + sb.append(getMap()); + return sb.toString(); + } + +} diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/SDCSchemeMC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/SDCSchemeMC.java new file mode 100644 index 00000000000..67cb808706f --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/SDCSchemeMC.java @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.compress.colgroup.scheme; + +import org.apache.sysds.runtime.compress.DMLCompressionException; +import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.colgroup.ASDC; +import org.apache.sysds.runtime.compress.colgroup.ASDCZero; +import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC; +import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; +import org.apache.sysds.runtime.compress.colgroup.ColGroupSDC; +import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; +import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary; +import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; +import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; +import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory; +import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; +import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; +import org.apache.sysds.runtime.compress.readers.ReaderColumnSelection; +import org.apache.sysds.runtime.compress.utils.DblArray; +import org.apache.sysds.runtime.compress.utils.DblArrayCountHashMap; +import org.apache.sysds.runtime.compress.utils.IntArrayList; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; + +public class SDCSchemeMC extends SDCScheme { + + private final DblArray emptyRow; + private final DblArray def; + private final DblArrayCountHashMap map; + + protected SDCSchemeMC(ASDC g) { + super(g.getColIndices()); + try { + this.lastDict = g.getDictionary(); + final MatrixBlockDictionary mbd = lastDict.getMBDict(this.cols.size()); + final MatrixBlock mbDict = mbd != null ? mbd.getMatrixBlock() : new MatrixBlock(1, this.cols.size(), 0.0); + final int dictRows = mbDict.getNumRows(); + final int dictCols = mbDict.getNumColumns(); + + // Read the mapping data and materialize map. + map = new DblArrayCountHashMap(dictRows * 2, dictCols); + final ReaderColumnSelection reader = ReaderColumnSelection.createReader(mbDict, // + ColIndexFactory.create(dictCols), false, 0, dictRows); + emptyRow = new DblArray(new double[dictCols]); + DblArray d = null; + int r = 0; + while((d = reader.nextRow()) != null) { + + final int row = reader.getCurrentRowIndex(); + if(row != r) { + map.increment(emptyRow, row - r); + r = row; + } + map.increment(d); + } + if(r < dictRows) { + map.increment(emptyRow, dictRows - r); + } + + def = new DblArray(g.getCommon()); + } + catch(Exception e) { + throw new DMLCompressionException(g.getDictionary().toString()); + } + } + + protected SDCSchemeMC(ASDCZero g) { + super(g.getColIndices()); + + this.lastDict = g.getDictionary(); + final MatrixBlock mbDict = lastDict.getMBDict(this.cols.size()).getMatrixBlock(); + final int dictRows = mbDict.getNumRows(); + final int dictCols = mbDict.getNumColumns(); + + // Read the mapping data and materialize map. + map = new DblArrayCountHashMap(dictRows * 2, dictCols); + final ReaderColumnSelection r = ReaderColumnSelection.createReader(mbDict, // + ColIndexFactory.create(dictCols), false, 0, dictRows); + DblArray d = null; + while((d = r.nextRow()) != null) + map.increment(d); + + emptyRow = new DblArray(new double[dictCols]); + def = new DblArray(new double[dictCols]); + } + + @Override + public AColGroup encode(MatrixBlock data, IColIndex columns) { + validate(data, columns); + final int nRow = data.getNumRows(); + if(data.isEmpty()) + return new ColGroupEmpty(columns); + // final AMapToData d = MapToFactory.create(nRow, map.size()); + + final IntArrayList offs = new IntArrayList(); + AMapToData d = encode(data, offs, cols); + + if(lastDict == null || lastDict.getNumberOfValues(columns.size()) != map.size()) + lastDict = DictionaryFactory.create(map, columns.size(), false, data.getSparsity()); + if(offs.size() == 0) + return ColGroupDDC.create(columns, lastDict, d, null); + else { + final AOffset off = OffsetFactory.createOffset(offs); + return ColGroupSDC.create(columns, nRow, lastDict, def.getData(), off, d, null); + } + } + + private AMapToData encode(MatrixBlock data, IntArrayList off, IColIndex cols) { + final int nRow = data.getNumRows(); + final ReaderColumnSelection reader = ReaderColumnSelection.createReader(// + data, cols, false, 0, nRow); + DblArray cellVals; + int emptyIdx = map.getId(emptyRow); + emptyRow.equals(def); + IntArrayList dt = new IntArrayList(); + + int r = 0; + while((cellVals = reader.nextRow()) != null) { + final int row = reader.getCurrentRowIndex(); + if(row != r) { + if(emptyIdx >= 0) { + // empty is non default. + while(r < row) { + off.appendValue(r++); + dt.appendValue(emptyIdx); + } + } + else { + r = row; + } + } + final int id = map.getId(cellVals); + if(id >= 0) { + off.appendValue(row); + dt.appendValue(id); + r++; + } + } + if(emptyIdx >= 0) { + // empty is non default. + while(r < nRow) { + off.appendValue(r++); + dt.appendValue(emptyIdx); + } + } + + AMapToData d = MapToFactory.create(off.size(), map.size()); + for(int i = 0; i < off.size(); i++) + d.set(i, dt.get(i)); + + return d; + } + + @Override + public ICLAScheme update(MatrixBlock data, IColIndex columns) { + validate(data, columns); + + if(data.isEmpty()) { + if(!def.equals(emptyRow)) + map.increment(emptyRow, data.getNumRows()); + return this; + } + final int nRow = data.getNumRows(); + final ReaderColumnSelection reader = ReaderColumnSelection.createReader(// + data, cols, false, 0, nRow); + DblArray cellVals; + final boolean defIsEmpty = emptyRow.equals(def); + + int r = 0; + while((cellVals = reader.nextRow()) != null) { + final int row = reader.getCurrentRowIndex(); + if(row != r) { + if(!defIsEmpty) + map.increment(emptyRow, row - r); + r = row; + } + final int id = map.getId(cellVals); + if(id >= 0) + map.increment(cellVals); + + } + if(!defIsEmpty) { + // empty is non default. + if(r < nRow) + map.increment(emptyRow, nRow - r); + } + + return this; + } + + protected Object getDef() { + return def; + } + + protected Object getMap() { + return map; + } +} diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/SDCSchemeSC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/SDCSchemeSC.java new file mode 100644 index 00000000000..420ef14432d --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/SDCSchemeSC.java @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.compress.colgroup.scheme; + +import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.colgroup.ASDC; +import org.apache.sysds.runtime.compress.colgroup.ASDCZero; +import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC; +import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; +import org.apache.sysds.runtime.compress.colgroup.ColGroupSDC; +import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; +import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; +import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; +import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory; +import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; +import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; +import org.apache.sysds.runtime.compress.utils.DoubleCountHashMap; +import org.apache.sysds.runtime.compress.utils.IntArrayList; +import org.apache.sysds.runtime.data.DenseBlock; +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; + +public class SDCSchemeSC extends SDCScheme { + + final double def; + final private DoubleCountHashMap map; + + protected SDCSchemeSC(ASDC g) { + this(g.getColIndices(), g.getCommon()[0], g.getDictionary()); + } + + protected SDCSchemeSC(ASDCZero g) { + this(g.getColIndices(), 0, g.getDictionary()); + } + + private SDCSchemeSC(IColIndex cols, double def, ADictionary lastDict) { + super(cols); + this.def = def; + this.lastDict = lastDict; + int unique = lastDict.getNumberOfValues(1); + map = new DoubleCountHashMap(unique); + + for(int i = 0; i < unique; i++) + map.increment(lastDict.getValue(i)); + } + + @Override + public AColGroup encode(MatrixBlock data, IColIndex columns) { + + validate(data, columns); + final int nRow = data.getNumRows(); + if(data.isEmpty()) + return new ColGroupEmpty(columns); + + // final AMapToData d = MapToFactory.create(nRow, map.size()); + final IntArrayList offs = new IntArrayList(); + AMapToData d = encode(data, offs, cols.get(0)); + if(lastDict == null || lastDict.getNumberOfValues(columns.size()) != map.size()) + lastDict = DictionaryFactory.create(map); + + if(offs.size() == 0) { + return ColGroupDDC.create(columns, lastDict, d, null); + } + else { + final AOffset off = OffsetFactory.createOffset(offs); + return ColGroupSDC.create(columns, nRow, lastDict, new double[] {def}, off, d, null); + } + } + + private AMapToData encode(MatrixBlock data, IntArrayList off, int col) { + + if(data.isInSparseFormat()) + return encodeSparse(data, off, col); + else if(data.getDenseBlock().isContiguous()) + return encodeDense(data, off, col); + else + return encodeGeneric(data, off, col); + } + + private AMapToData encodeSparse(MatrixBlock data, IntArrayList off, int col) { + final int nRow = data.getNumRows(); + final SparseBlock sb = data.getSparseBlock(); + // full iteration + for(int i = 0; i < nRow; i++) + if(sb.get(i, col) != def) + off.appendValue(i); + + // Only cells with non default values. + AMapToData d = MapToFactory.create(off.size(), map.size()); + for(int i = 0; i < off.size(); i++) { + int r = off.get(i); + d.set(i, map.getId(sb.get(r, col))); + } + return d; + } + + private AMapToData encodeDense(MatrixBlock data, IntArrayList off, int col) { + final int nRow = data.getNumRows(); + final double[] vals = data.getDenseBlockValues(); + final int nCol = data.getNumColumns(); + final int max = nRow * nCol; // guaranteed lower than intmax. + // full iteration + for(int i = 0, o = col; o < max; i++, o += nCol) { + if(vals[o] != def) + off.appendValue(i); + } + + // Only cells with non default values. + AMapToData d = MapToFactory.create(off.size(), map.size()); + for(int i = 0; i < off.size(); i++) { + int o = off.get(i) * nCol + col; + d.set(i, map.getId(vals[o])); + } + return d; + } + + private AMapToData encodeGeneric(MatrixBlock data, IntArrayList off, int col) { + final int nRow = data.getNumRows(); + final DenseBlock db = data.getDenseBlock(); + + // full iteration + for(int i = 0; i < nRow; i++) { + final double[] c = db.values(i); + final int o = db.pos(i) + col; + if(c[o] != def) + off.appendValue(i); + } + + // Only cells with non default values. + AMapToData d = MapToFactory.create(off.size(), map.size()); + for(int i = 0; i < off.size(); i++) { + final int of = off.get(i); + final int o = db.pos(of) + col; + d.set(i, map.getId(db.values(of)[o])); + } + return d; + } + + @Override + public ICLAScheme update(MatrixBlock data, IColIndex columns) { + validate(data, columns); + + final int col = columns.get(0); + if(data.isEmpty()) { + if(def != 0.0) + map.increment(0.0, data.getNumRows()); + } + else if(data.isInSparseFormat()) + updateSparse(data, col); + else if(data.getDenseBlock().isContiguous()) + updateDense(data, col); + else + updateGeneric(data, col); + return this; + } + + private void updateSparse(MatrixBlock data, int col) { + final int nRow = data.getNumRows(); + final SparseBlock sb = data.getSparseBlock(); + for(int i = 0; i < nRow; i++) { + final double v = sb.get(i, col); + if(v != def) + map.increment(v); + } + } + + private void updateDense(MatrixBlock data, int col) { + final int nRow = data.getNumRows(); + final double[] vals = data.getDenseBlockValues(); + final int nCol = data.getNumColumns(); + final int max = nRow * nCol; // guaranteed lower than intmax. + for(int off = col; off < max; off += nCol) { + final double v = vals[off]; + if(v != def) + map.increment(v); + } + + } + + private void updateGeneric(MatrixBlock data, int col) { + final int nRow = data.getNumRows(); + final DenseBlock db = data.getDenseBlock(); + for(int i = 0; i < nRow; i++) { + final double[] c = db.values(i); + final int off = db.pos(i) + col; + final double v = c[off]; + if(v != def) + map.increment(v); + } + } + + protected Object getDef() { + return def; + } + + protected Object getMap() { + return map; + } + +} diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/SchemeFactory.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/SchemeFactory.java index ebdfab0a0b5..d9e558b9554 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/SchemeFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/SchemeFactory.java @@ -26,16 +26,16 @@ public class SchemeFactory { public static ICLAScheme create(IColIndex columns, CompressionType type) { switch(type) { - case CONST: - return ConstScheme.create(columns); case DDC: return DDCScheme.create(columns); case DDCFOR: break; case DeltaDDC: break; + case CONST: + // const is automatically empty if no data is provided. case EMPTY: - break; + return new EmptyScheme(columns); case LinearFunctional: break; case OLE: diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstExact.java b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstExact.java index 63af720223a..6483eba1048 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstExact.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstExact.java @@ -20,6 +20,7 @@ package org.apache.sysds.runtime.compress.estim; import org.apache.sysds.runtime.compress.CompressionSettings; +import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.estim.encoding.EmptyEncoding; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; @@ -39,7 +40,7 @@ public ComEstExact(MatrixBlock data, CompressionSettings compSettings) { public CompressedSizeInfoColGroup getColGroupInfo(IColIndex colIndexes, int estimate, int nrUniqueUpperBound) { final IEncode map = EncodingFactory.createFromMatrixBlock(_data, _cs.transposed, colIndexes); if(map instanceof EmptyEncoding) - return new CompressedSizeInfoColGroup(colIndexes, getNumRows()); + return new CompressedSizeInfoColGroup(colIndexes, getNumRows(), CompressionType.EMPTY); return getFacts(map, colIndexes); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java index 0b7e9050605..bfc5ffe9458 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java @@ -23,6 +23,7 @@ import java.util.Random; import org.apache.sysds.runtime.compress.CompressionSettings; +import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; @@ -77,7 +78,7 @@ public CompressedSizeInfoColGroup getColGroupInfo(IColIndex colIndexes, int esti if(_data.isEmpty() || (nnzCols != null && colIndexes.size() == 1 && nnzCols[colIndexes.get(0)] == 0) || (_cs.transposed && colIndexes.size() == 1 && _data.isInSparseFormat() && _data.getSparseBlock().isEmpty(colIndexes.get(0)))) - return new CompressedSizeInfoColGroup(colIndexes, getNumRows()); + return new CompressedSizeInfoColGroup(colIndexes, getNumRows(), CompressionType.EMPTY); final IEncode map = EncodingFactory.createFromMatrixBlock(_sample, _transposed, colIndexes); return extractInfo(map, colIndexes, maxDistinct); diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java index 49a30f65b09..1168147b3d2 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java @@ -79,7 +79,8 @@ public CompressedSizeInfoColGroup(IColIndex cols, EstimationFactors facts, long _sizes.put(bestCompressionType, _minSize); } - public CompressedSizeInfoColGroup(IColIndex columns, EstimationFactors facts, long minSize, CompressionType bestCompression, IEncode map){ + public CompressedSizeInfoColGroup(IColIndex columns, EstimationFactors facts, long minSize, + CompressionType bestCompression, IEncode map) { _cols = columns; _facts = facts; _minSize = minSize; @@ -110,22 +111,30 @@ public CompressedSizeInfoColGroup(IColIndex columns, EstimationFactors facts, } /** - * Create empty. + * Create empty or const. * * @param columns columns * @param nRows number of rows + * @param ct The type intended either Empty or Const */ - public CompressedSizeInfoColGroup(IColIndex columns, int nRows) { + public CompressedSizeInfoColGroup(IColIndex columns, int nRows, CompressionType ct) { _cols = columns; _facts = new EstimationFactors(0, nRows); - _sizes = new EnumMap<>(CompressionType.class); - final CompressionType ct = CompressionType.EMPTY; - _sizes.put(ct, (double) ColGroupSizes.estimateInMemorySizeEMPTY(columns.size(), columns.isContiguous())); + switch(ct) { + case EMPTY: + _sizes.put(ct, (double) ColGroupSizes.estimateInMemorySizeEMPTY(columns.size(), columns.isContiguous())); + break; + case CONST: + _sizes.put(ct, + (double) ColGroupSizes.estimateInMemorySizeCONST(columns.size(), columns.isContiguous(), 1.0, false)); + break; + default: + throw new DMLCompressionException("Invalid instantiation of const Cost"); + } _bestCompressionType = ct; _minSize = _sizes.get(ct); _map = null; - } public double getCompressionSize(CompressionType ct) { @@ -213,11 +222,11 @@ private static EnumMap calculateCompressionSizes(IColIn } public boolean isEmpty() { - return _bestCompressionType == CompressionType.EMPTY; + return _bestCompressionType == CompressionType.EMPTY || _sizes.containsKey(CompressionType.EMPTY); } public boolean isConst() { - return _bestCompressionType == CompressionType.CONST; + return _bestCompressionType == CompressionType.CONST || _sizes.containsKey(CompressionType.CONST); } private static double getCompressionSize(IColIndex cols, CompressionType ct, EstimationFactors fact) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/DenseEncoding.java b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/DenseEncoding.java index db1905eccc7..4da9d8462da 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/DenseEncoding.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/DenseEncoding.java @@ -141,7 +141,7 @@ protected DenseEncoding combineDense(final DenseEncoding other) { final AMapToData ret = MapToFactory.create(size, maxUnique); - if(maxUnique > size) { + if(maxUnique > size && maxUnique > 2048) { // aka there is more maxUnique than rows. final Map m = new HashMap<>(size); return combineDenseWithHashMap(lm, rm, size, nVL, ret, m); diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibAggTernaryOp.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibAggTernaryOp.java new file mode 100644 index 00000000000..5d38c25ece3 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibAggTernaryOp.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.compress.lib; + +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.compress.DMLCompressionException; +import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; +import org.apache.sysds.runtime.functionobjects.KahanPlus; +import org.apache.sysds.runtime.functionobjects.Multiply; +import org.apache.sysds.runtime.functionobjects.ReduceAll; +import org.apache.sysds.runtime.functionobjects.ReduceRow; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.AggregateTernaryOperator; + +public final class CLALibAggTernaryOp { + private static final Log LOG = LogFactory.getLog(CLALibAggTernaryOp.class.getName()); + + private final MatrixBlock m1; + private final MatrixBlock m2; + private final MatrixBlock m3; + private final MatrixBlock ret; + private final AggregateTernaryOperator op; + private final boolean inCP; + private static boolean warned = false; + + public static MatrixBlock agg(MatrixBlock m1, MatrixBlock m2, MatrixBlock m3, MatrixBlock ret, + AggregateTernaryOperator op, boolean inCP) { + + int rl = (op.indexFn instanceof ReduceRow) ? 2 : 1; + int cl = (op.indexFn instanceof ReduceRow) ? m1.getNumColumns() : 2; + if(ret == null) + ret = new MatrixBlock(rl, cl, false); + else + ret.reset(rl, cl, false); + ret = new CLALibAggTernaryOp(m1, m2, m3, ret, op, inCP).exec(); + + return ret; + } + + private CLALibAggTernaryOp(MatrixBlock m1, MatrixBlock m2, MatrixBlock m3, MatrixBlock ret, + AggregateTernaryOperator op, boolean inCP) { + this.m1 = m1; + this.m2 = m2; + this.m3 = m3; + this.ret = ret; + this.op = op; + this.inCP = inCP; + } + + private MatrixBlock exec() { + if(op.indexFn instanceof ReduceAll && op.aggOp.increOp.fn instanceof KahanPlus && + op.binaryFn instanceof Multiply) { + // early abort if if anyEmpty. + if(m1.isEmptyBlock(false) || m2.isEmptyBlock(false) || m3 != null && m3.isEmptyBlock(false)) { + if(op.aggOp.existsCorrection() && inCP) + ret.dropLastRowsOrColumns(op.aggOp.correction); + return ret; + } + + // if any is constant. + if(isConst(m1)) { + double v = m1.quickGetValue(0, 0); + if(v == 1.0) + return new CLALibAggTernaryOp(m2, m3, null, ret, op, inCP).exec(); + } + } + return fallBack(); + } + + private static boolean isConst(MatrixBlock m) { + if(m != null && m instanceof CompressedMatrixBlock) { + List gs = ((CompressedMatrixBlock) m).getColGroups(); + return gs.size() == 1 && gs.get(0) instanceof ColGroupConst; + } + return false; + } + + private MatrixBlock fallBack() { + warnDecompression(); + MatrixBlock m1UC = CompressedMatrixBlock.getUncompressed(m1); + MatrixBlock m2UC = CompressedMatrixBlock.getUncompressed(m2); + MatrixBlock m3UC = CompressedMatrixBlock.getUncompressed(m3); + + MatrixBlock ret2 = MatrixBlock.aggregateTernaryOperations(m1UC, m2UC, m3UC, ret, op, inCP); + if(ret2.getNumRows() == 0 || ret2.getNumColumns() == 0) + throw new DMLCompressionException("Invalid output"); + return ret2; + } + + private void warnDecompression() { + + if(!warned) { + + boolean m1C = m1 instanceof CompressedMatrixBlock; + boolean m2C = m2 instanceof CompressedMatrixBlock; + boolean m3C = m3 instanceof CompressedMatrixBlock; + StringBuilder sb = new StringBuilder(120); + + sb.append("aggregateTernaryOperations "); + sb.append(op.aggOp.getClass().getSimpleName()); + sb.append(" "); + sb.append(op.indexFn.getClass().getSimpleName()); + sb.append(" "); + sb.append(op.aggOp.increOp.fn.getClass().getSimpleName()); + sb.append(" "); + sb.append(op.binaryFn.getClass().getSimpleName()); + sb.append(" m1,m2,m3 "); + sb.append(m1C); + sb.append(" "); + sb.append(m2C); + sb.append(" "); + sb.append(m3C); + + LOG.warn(sb.toString()); + warned = true; + } + } + +} diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java index 801ef893dcf..050db80bc17 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java @@ -80,17 +80,16 @@ public static List combine(CompressedMatrixBlock cmb, int k) { } public static List combine(CompressedMatrixBlock cmb, CompressedSizeInfo csi, ExecutorService pool) { - List input = cmb.getColGroups(); - boolean filterFor = CLALibUtils.shouldFilterFOR(input); + + final boolean filterFor = CLALibUtils.shouldFilterFOR(input); double[] c = filterFor ? new double[cmb.getNumColumns()] : null; if(filterFor) input = CLALibUtils.filterFOR(input, c); List> combinations = new ArrayList<>(); - for(CompressedSizeInfoColGroup gi : csi.getInfo()) { + for(CompressedSizeInfoColGroup gi : csi.getInfo()) combinations.add(findGroupsInIndex(gi.getColumns(), input)); - } List ret = new ArrayList<>(); if(filterFor) @@ -99,16 +98,15 @@ public static List combine(CompressedMatrixBlock cmb, CompressedSizeI else for(List combine : combinations) ret.add(combineN(combine)); - return ret; } public static List findGroupsInIndex(IColIndex idx, List groups) { List ret = new ArrayList<>(); - for(AColGroup g : groups) { + for(AColGroup g : groups) if(g.getColIndices().containsAny(idx)) ret.add(g); - } + return ret; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScheme.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScheme.java new file mode 100644 index 00000000000..5abbd9bd0f4 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScheme.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.compress.lib; + +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType; +import org.apache.sysds.runtime.compress.colgroup.indexes.SingleIndex; +import org.apache.sysds.runtime.compress.colgroup.scheme.CompressionScheme; +import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; +import org.apache.sysds.runtime.compress.colgroup.scheme.SchemeFactory; + +public class CLALibScheme { + + public static CompressionScheme getScheme(CompressedMatrixBlock cmb) { + return CompressionScheme.getScheme(cmb); + } + + /** + * Generate a scheme with the given type of columnGroup and number of columns in each group + * + * @param type The type of encoding to use + * @param nCols The number of columns + * @return A scheme to generate. + */ + public static CompressionScheme genScheme(CompressionType type, int nCols) { + ICLAScheme[] encodings = new ICLAScheme[nCols]; + for(int i = 0; i < nCols; i++) + encodings[i] = SchemeFactory.create(new SingleIndex(i), type); + return new CompressionScheme(encodings); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSlice.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSlice.java index a99142ec005..9373a036062 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSlice.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSlice.java @@ -31,6 +31,9 @@ import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.colgroup.AColGroup; import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; +import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; +import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; +import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.CommonThreadPool; @@ -143,6 +146,7 @@ private static MatrixBlock sliceRowsDecompress(CompressedMatrixBlock cmb, int rl private static MatrixBlock sliceRowsCompressed(CompressedMatrixBlock cmb, int rl, int ru) { final List groups = cmb.getColGroups(); final List newColGroups = new ArrayList<>(groups.size()); + final List emptyGroups = new ArrayList<>(); final int rue = ru + 1; final CompressedMatrixBlock ret = new CompressedMatrixBlock(rue - rl, cmb.getNumColumns()); @@ -151,11 +155,18 @@ private static MatrixBlock sliceRowsCompressed(CompressedMatrixBlock cmb, int rl final AColGroup slice = grp.sliceRows(rl, rue); if(slice != null) newColGroups.add(slice); + else + emptyGroups.add(grp.getColIndices()); } if(newColGroups.size() == 0) return new MatrixBlock(rue - rl, cmb.getNumColumns(), 0.0); + if(!emptyGroups.isEmpty()){ + IColIndex empties = ColIndexFactory.combineIndexes(emptyGroups); + newColGroups.add(new ColGroupEmpty(empties)); + } + ret.allocateColGroupList(newColGroups); ret.recomputeNonZeros(); ret.setOverlapping(cmb.isOverlapping()); diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibStack.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibStack.java index 1dd76483c61..178c13ad297 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibStack.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibStack.java @@ -156,7 +156,7 @@ private static MatrixBlock combine(final Map m, fina final int c = cols.next(); if(colTypes[c + off] != t) { LOG.warn("Not supported different types of column groups to combine." - + "Falling back to decompression of all blocks"); + + "Falling back to decompression of all blocks " + t + " vs " + colTypes[c + off]); return combineViaDecompression(m, rlen, clen, blen, k); } } @@ -192,7 +192,6 @@ private static MatrixBlock combineColumnGroups(final Map filterFOR(List groups, double[] cons for(AColGroup g : groups) if(g instanceof IFrameOfReferenceGroup) filteredGroups.add(((IFrameOfReferenceGroup) g).extractCommon(constV)); + else + filteredGroups.add(g); return filteredGroups; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleCountHashMap.java b/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleCountHashMap.java index c496b447f41..702072cb338 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleCountHashMap.java +++ b/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleCountHashMap.java @@ -35,7 +35,7 @@ public class DoubleCountHashMap { private Bucket[] _data = null; public DoubleCountHashMap(int init_capacity) { - _data = new Bucket[(Util.getPow2(init_capacity)/2) + 7]; + _data = new Bucket[(Util.getPow2(init_capacity) / 2) + 7]; // _data = new Bucket[(Util.getPow2(init_capacity)) ]; _size = 0; } @@ -70,7 +70,7 @@ public final int increment(final double key) { return l.v.id; } else - l = l.n; + l = l.n; } return addNewBucket(ix, key); } @@ -84,7 +84,7 @@ public final int increment(final double key, final int count) { return l.v.id; } else - l = l.n; + l = l.n; } return addNewBucket(ix, key); } @@ -106,38 +106,39 @@ private int addNewBucket(final int ix, final double key) { * @return count on key */ public int get(double key) { - try{ + try { int ix = hashIndex(key); Bucket l = _data[ix]; while(!(l.v.key == key)) l = l.n; - + return l.v.count; - } catch( Exception e){ + } + catch(Exception e) { if(Double.isNaN(key)) return get(0.0); throw e; } } - /** + /** * Get the ID behind the key, if it does not exist -1 is returned. * * @param key The key array * @return The Id or -1 */ public int getId(double key) { - try{ + try { int ix = hashIndex(key); Bucket l = _data[ix]; while(!(l.v.key == key)) l = l.n; - return l.v.id; - } catch( Exception e){ + } + catch(Exception e) { if(Double.isNaN(key)) return get(0.0); - throw e; + throw new RuntimeException("Failed to getKey : " + key + " in " + this, e); } } @@ -173,12 +174,11 @@ public void replaceWithUIDs() { } } - public void replaceWithUIDsNoZero() { int i = 0; for(Bucket e : _data) { while(e != null) { - if(e.v.key != 0) + if(e.v.key != 0) e.v.count = i++; e = e.n; } @@ -214,12 +214,12 @@ public int[] getUnorderedCountsAndReplaceWithUIDsWithout0() { return counts; } - public double getMostFrequent(){ + public double getMostFrequent() { double f = 0; int fq = 0; - for(Bucket e: _data){ - while(e != null){ - if(e.v.count > fq){ + for(Bucket e : _data) { + while(e != null) { + if(e.v.count > fq) { fq = e.v.count; f = e.v.key; } @@ -261,7 +261,7 @@ public double[] getDictionary() { private final int hashIndex(final double key) { // Option 1 ... conflict on 1 vs -1 final long bits = Double.doubleToLongBits(key); - return Math.abs((int)(bits ^ (bits >>> 32)) % _data.length); + return Math.abs((int) (bits ^ (bits >>> 32)) % _data.length); } // private static int indexFor(int h, int length) { @@ -288,7 +288,7 @@ public String toString() { @Override public String toString() { StringBuilder sb = new StringBuilder(); - sb.append(this.getClass().getSimpleName() + this.hashCode()); + sb.append(this.getClass().getSimpleName()); for(int i = 0; i < _data.length; i++) if(_data[i] != null) sb.append(", " + _data[i]); diff --git a/src/main/java/org/apache/sysds/runtime/compress/utils/IntArrayList.java b/src/main/java/org/apache/sysds/runtime/compress/utils/IntArrayList.java index a78e73dac93..6a435c49de6 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/utils/IntArrayList.java +++ b/src/main/java/org/apache/sysds/runtime/compress/utils/IntArrayList.java @@ -21,16 +21,17 @@ import java.util.Arrays; +import org.apache.sysds.runtime.compress.DMLCompressionException; + public class IntArrayList { private static final int INIT_CAPACITY = 4; private static final int RESIZE_FACTOR = 2; - private int[] _data = null; + private int[] _data; private int _size; public IntArrayList() { - _data = null; - _size = 0; + this(INIT_CAPACITY); } public IntArrayList(int initialSize) { @@ -39,6 +40,8 @@ public IntArrayList(int initialSize) { } public IntArrayList(int[] values) { + if(values == null) + throw new DMLCompressionException("Invalid initialization of IntArrayList"); _data = values; _size = values.length; } @@ -49,10 +52,7 @@ public int size() { public void appendValue(int value) { // allocate or resize array if necessary - if(_data == null) { - _data = new int[INIT_CAPACITY]; - } - else if(_size + 1 >= _data.length) + if(_size + 1 >= _data.length) resize(); // append value @@ -71,10 +71,7 @@ public int[] extractValues() { } public int get(int index) { - if(_data != null) - return _data[index]; - else - throw new RuntimeException("invalid index to get"); + return _data[index]; } public int[] extractValues(boolean trim) { @@ -94,13 +91,14 @@ private void resize() { @Override public String toString() { StringBuilder sb = new StringBuilder(); - + if(_size == 0) + return "[]"; sb.append("["); int i = 0; for(; i < _size - 1; i++) - sb.append(_data[i] + ","); - - sb.append(_data[i] + "]"); + sb.append(_data[i]).append(", "); + sb.append(_data[i]); + sb.append("]"); return sb.toString(); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateTernaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateTernaryCPInstruction.java index 01c966314ec..e93ab5aea53 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateTernaryCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateTernaryCPInstruction.java @@ -68,7 +68,7 @@ public void processInstruction(ExecutionContext ec) { AggregateTernaryOperator ab_op = (AggregateTernaryOperator) _optr; validateInput(matBlock1, matBlock2, matBlock3, ab_op); - MatrixBlock ret = matBlock1 + MatrixBlock ret = MatrixBlock .aggregateTernaryOperations(matBlock1, matBlock2, matBlock3, new MatrixBlock(), ab_op, true); // release inputs/outputs diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateTernarySPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateTernarySPInstruction.java index 29134a0dc00..6d179fc8fc3 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateTernarySPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateTernarySPInstruction.java @@ -138,7 +138,7 @@ public Tuple2 call(Tuple2(new MatrixIndexes(1, ix.getColumnIndex()), - in1.aggregateTernaryOperations(in1, in2, in3, new MatrixBlock(), _aggop, false)); + MatrixBlock.aggregateTernaryOperations(in1, in2, in3, new MatrixBlock(), _aggop, false)); } } @@ -164,7 +164,7 @@ public Tuple2 call(Tuple2(new MatrixIndexes(1, ix.getColumnIndex()), - in1.aggregateTernaryOperations(in1, in2, null, new MatrixBlock(), _aggop, false)); + MatrixBlock.aggregateTernaryOperations(in1, in2, null, new MatrixBlock(), _aggop, false)); } } } diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java index 2378f73169c..61eecf251f9 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java @@ -27,6 +27,8 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types.CorrectionLocationType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.codegen.SpoofOperator.SideInput; @@ -87,7 +89,7 @@ * TODO next opcode extensions: a+, colindexmax */ public class LibMatrixAgg { - // private static final Log LOG = LogFactory.getLog(LibMatrixAgg.class.getName()); + protected static final Log LOG = LogFactory.getLog(LibMatrixAgg.class.getName()); //internal configuration parameters private static final boolean NAN_AWARENESS = false; @@ -512,8 +514,11 @@ public static MatrixBlock aggregateTernary(MatrixBlock in1, MatrixBlock in2, Mat public static MatrixBlock aggregateTernary(MatrixBlock in1, MatrixBlock in2, MatrixBlock in3, MatrixBlock ret, AggregateTernaryOperator op, int k) { //fall back to sequential version if necessary - if( k <= 1 || in1.nonZeros+in2.nonZeros < PAR_NUMCELL_THRESHOLD1 || in1.rlen <= k/2 - || (!(op.indexFn instanceof ReduceCol) && ret.clen*8*k > PAR_INTERMEDIATE_SIZE_THRESHOLD) ) { + if( k <= 1 + || in1.nonZeros+in2.nonZeros < PAR_NUMCELL_THRESHOLD1 + || in1.rlen <= k/2 + // || (!(op.indexFn instanceof ReduceCol) && ret.clen*8*k > PAR_INTERMEDIATE_SIZE_THRESHOLD) + ) { return aggregateTernary(in1, in2, in3, ret, op); } @@ -636,7 +641,7 @@ public static boolean satisfiesMultiThreadingConstraints(MatrixBlock in, MatrixB && in.nonZeros > (sharedTP ? PAR_NUMCELL_THRESHOLD2 : PAR_NUMCELL_THRESHOLD1); } - public static boolean satisfiesMultiThreadingConstraints(MatrixBlock in,int k) { + public static boolean satisfiesMultiThreadingConstraints(MatrixBlock in, int k) { boolean sharedTP = (InfrastructureAnalyzer.getLocalParallelism() == k); return k > 1 && in.rlen > (sharedTP ? k/8 : k/2) && in.nonZeros > (sharedTP ? PAR_NUMCELL_THRESHOLD2 : PAR_NUMCELL_THRESHOLD1); diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java index ebf4a6aeb3b..5aaf0cd46a5 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java @@ -51,6 +51,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.DMLCompressionException; +import org.apache.sysds.runtime.compress.lib.CLALibAggTernaryOp; import org.apache.sysds.runtime.controlprogram.caching.CacheBlock; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; @@ -967,11 +968,21 @@ public MatrixBlock colMax() { * @return the maximum value of all values in the matrix */ public double max() { - MatrixBlock out = new MatrixBlock(1, 1, false); - LibMatrixAgg.aggregateUnaryMatrix(this, out, - InstructionUtils.parseBasicAggregateUnaryOperator("uamax", 1)); + AggregateUnaryOperator op =InstructionUtils.parseBasicAggregateUnaryOperator("uamax", 1); + MatrixBlock out = aggregateUnaryOperations(op, null, 1000, null, true); return out.quickGetValue(0, 0); } + + /** + * Wrapper method for reduceall-max of a matrix. + * + * @param k the parallelization degree + * @return the maximum value of all values in the matrix + */ + public MatrixBlock max(int k){ + AggregateUnaryOperator op = InstructionUtils.parseBasicAggregateUnaryOperator("uamax", k); + return aggregateUnaryOperations(op, null, 1000, null, true); + } /** * Wrapper method for reduceall-sum of a matrix. @@ -983,6 +994,17 @@ public double sum() { return sumWithFn(kplus); } + /** + * Wrapper method for reduceall-sum of a matrix parallel + * + * @param k parallelization degree + * @return Sum of the values in the matrix. + */ + public MatrixBlock sum(int k) { + AggregateUnaryOperator op = InstructionUtils.parseBasicAggregateUnaryOperator("uak+", k); + return aggregateUnaryOperations(op, null, 1000, null, true); + } + /** * Wrapper method for single threaded reduceall-colSum of a matrix. * @@ -4982,15 +5004,11 @@ private void checkAggregateBinaryOperationsCommon(MatrixBlock m1, MatrixBlock m2 throw new DMLRuntimeException("Invalid aggregateBinaryOperatio: one of either input should be this"); } - public MatrixBlock aggregateTernaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock m3, MatrixBlock ret, + public static MatrixBlock aggregateTernaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock m3, MatrixBlock ret, AggregateTernaryOperator op, boolean inCP) { - if(m1 instanceof CompressedMatrixBlock) - m1 = ((CompressedMatrixBlock) m1).getUncompressed("Aggregate Ternary Operator arg1 " + op.getClass().getSimpleName(), op.getNumThreads()); - if(m2 instanceof CompressedMatrixBlock) - m2 = ((CompressedMatrixBlock) m2).getUncompressed("Aggregate Ternary Operator arg2 " + op.getClass().getSimpleName(), op.getNumThreads()); - if(m3 instanceof CompressedMatrixBlock) - m3 = ((CompressedMatrixBlock) m3).getUncompressed("Aggregate Ternary Operator arg3 " + op.getClass().getSimpleName(), op.getNumThreads()); - + if(m1 instanceof CompressedMatrixBlock || m2 instanceof CompressedMatrixBlock || m3 instanceof CompressedMatrixBlock) + return CLALibAggTernaryOp.agg(m1, m2, m3, ret, op, inCP); + //create output matrix block w/ corrections int rl = (op.indexFn instanceof ReduceRow) ? 2 : 1; int cl = (op.indexFn instanceof ReduceRow) ? m1.clen : 2; @@ -5626,15 +5644,15 @@ public static MatrixBlock randOperations(int rows, int cols, double sparsity, do /** * Function to generate the random matrix with specified dimensions (block sizes are not specified). - * - * @param rows number of rows - * @param cols number of columns + * + * @param rows number of rows + * @param cols number of columns * @param sparsity sparsity as a percentage - * @param min minimum value - * @param max maximum value - * @param pdf pdf - * @param seed random seed - * @param k ? + * @param min minimum value + * @param max maximum value + * @param pdf pdf + * @param seed random seed + * @param k The number of threads in the operation * @return matrix block */ public static MatrixBlock randOperations(int rows, int cols, double sparsity, double min, double max, String pdf, long seed, int k) { @@ -5663,7 +5681,7 @@ public static MatrixBlock randOperations(RandomMatrixGenerator rgen, long seed) * * @param rgen random matrix generator * @param seed seed value - * @param k ? + * @param k The number of threads to use in the operation * @return matrix block */ public static MatrixBlock randOperations(RandomMatrixGenerator rgen, long seed, int k) { diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedCustomTests.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedCustomTests.java index 657551fc02c..60c9342d73a 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/CompressedCustomTests.java +++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedCustomTests.java @@ -324,8 +324,8 @@ public void createUncompressedCompressedMatrixBlockTest() { TestUtils.compareMatricesBitAvgDistance(mb, mb2, 0, 0); } - @Test(expected = DMLCompressionException.class) - public void invalidIfNnzNotSet() { + @Test + public void notInvalidIfNnzNotSet() { MatrixBlock mb = TestUtils.generateTestMatrixBlock(32, 42, 32, 123, 0.2, 2135); mb.setNonZeros(-23L); CompressedMatrixBlockFactory.compress(mb); diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedLoggingTests.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedLoggingTests.java index 650ad8091f4..28b012cbf5d 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/CompressedLoggingTests.java +++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedLoggingTests.java @@ -36,6 +36,7 @@ import org.apache.sysds.test.LoggingUtils; import org.apache.sysds.test.LoggingUtils.TestAppender; import org.apache.sysds.test.TestUtils; +import org.junit.Ignore; import org.junit.Test; public class CompressedLoggingTests { @@ -370,6 +371,7 @@ public void compressedLoggingTest_recompress() { } @Test + @Ignore public void compressedLoggingTest_AbortEnd() { final TestAppender appender = LoggingUtils.overwrite(); @@ -381,8 +383,9 @@ public void compressedLoggingTest_AbortEnd() { CompressionSettingsBuilder sb = new CompressionSettingsBuilder(); sb.setMaxSampleSize(ss); sb.setMinimumSampleSize(ss); - CompressedMatrixBlockFactory.compress(mb, sb).getLeft(); + MatrixBlock cmb = CompressedMatrixBlockFactory.compress(mb, sb).getLeft(); final List log = LoggingUtils.reinsert(appender); + LOG.error(cmb); for(LoggingEvent l : log) { // LOG.error(l.getMessage()); if(l.getMessage().toString().contains("Abort block compression")) @@ -449,7 +452,6 @@ public void compressionSettingsEstimationType() { } } - @Test public void compressionSettingsFull() { final TestAppender appender = LoggingUtils.overwrite(); @@ -462,7 +464,7 @@ public void compressionSettingsFull() { if(l.getMessage().toString().contains("Estimation Type")) fail("Contained estimationType"); } - + } catch(Exception e) { e.printStackTrace(); @@ -473,4 +475,31 @@ public void compressionSettingsFull() { LoggingUtils.reinsert(appender); } } + + @Test + public void compressedLoggingTest_NNzNotSet() { + final TestAppender appender = LoggingUtils.overwrite(); + + try { + Logger.getLogger(CompressedMatrixBlockFactory.class).setLevel(Level.WARN); + MatrixBlock mb = TestUtils.generateTestMatrixBlock(1000, 5, 1, 1, 0.5, 235); + mb.setNonZeros(-1); + MatrixBlock m2 = CompressedMatrixBlockFactory.compress(mb).getLeft(); + TestUtils.compareMatrices(mb, m2, 0.0); + final List log = LoggingUtils.reinsert(appender); + for(LoggingEvent l : log) { + if(l.getMessage().toString().contains("Recomputing non-zeros")) + return; + } + fail("NonZeros not set warning not printed"); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + finally { + Logger.getLogger(CompressedMatrixBlockFactory.class).setLevel(Level.WARN); + LoggingUtils.reinsert(appender); + } + } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java index 754090fe2dc..3c036a92f7a 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java @@ -418,9 +418,11 @@ public void testAggregateTernaryOperation() { MatrixBlock m2 = new MatrixBlock(nrow, ncol, 13.0); MatrixBlock m3 = new MatrixBlock(nrow, ncol, 14.0); - MatrixBlock ret1 = cmb.aggregateTernaryOperations(cmb, m2, m3, null, op, true); - ucRet = mb.aggregateTernaryOperations(mb, m2, m3, ucRet, op, true); + MatrixBlock ret1 = MatrixBlock.aggregateTernaryOperations(cmb, m2, m3, null, op, true); + ucRet = MatrixBlock.aggregateTernaryOperations(mb, m2, m3, ucRet, op, true); + // LOG.error(ret1); + // LOG.error(ucRet); compareResultMatrices(ucRet, ret1, 1); } catch(Exception e) { @@ -445,8 +447,8 @@ public void testAggregateTernaryOperationZero() { MatrixBlock m2 = new MatrixBlock(nrow, ncol, 0); MatrixBlock m3 = new MatrixBlock(nrow, ncol, 14.0); - MatrixBlock ret1 = cmb.aggregateTernaryOperations(cmb, m2, m3, null, op, true); - ucRet = mb.aggregateTernaryOperations(mb, m2, m3, ucRet, op, true); + MatrixBlock ret1 = MatrixBlock.aggregateTernaryOperations(cmb, m2, m3, null, op, true); + ucRet = MatrixBlock.aggregateTernaryOperations(mb, m2, m3, ucRet, op, true); compareResultMatrices(ucRet, ret1, 1); } diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java index 69bcd21a4c8..7b791180e7a 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java +++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java @@ -50,6 +50,7 @@ import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.colgroup.indexes.IIterate; import org.apache.sysds.runtime.compress.cost.ACostEstimate; import org.apache.sysds.runtime.compress.cost.CostEstimatorBuilder; import org.apache.sysds.runtime.compress.cost.CostEstimatorFactory; @@ -684,6 +685,8 @@ public void testLeftMatrixMatrixMultiplicationTransposed(MatrixBlock matrix, boo if(compressMatrix && !(compMatrix instanceof CompressedMatrixBlock)) return; // Early termination since the test does not test what we wanted. + compareResultMatrices(matrix, compMatrix, 0.0); + // Make Operator AggregateBinaryOperator abopSingle = InstructionUtils.getMatMultOperator(1); @@ -1054,6 +1057,7 @@ public void testSlice(int rl, int ru, int cl, int cu) { public void testCompressAgain() { try { TestUtils.assertEqualColsAndRows(mb, cmb); + compareResultMatrices(mb, cmb, 1); MatrixBlock cmba = CompressedMatrixBlockFactory.compress(cmb, _k).getLeft(); compareResultMatrices(mb, cmba, 1); } @@ -1139,14 +1143,14 @@ public void testRandOperationsInPlace() { protected void compareResultMatrices(MatrixBlock expected, MatrixBlock result, double toleranceMultiplier) { TestUtils.assertEqualColsAndRows(expected, result); - if(expected instanceof CompressedMatrixBlock) + if(expected instanceof CompressedMatrixBlock) { + verifyContainsAllColumns((CompressedMatrixBlock) expected); expected = ((CompressedMatrixBlock) expected).decompress(); - if(result instanceof CompressedMatrixBlock) + } + if(result instanceof CompressedMatrixBlock) { + verifyContainsAllColumns((CompressedMatrixBlock) result); result = ((CompressedMatrixBlock) result).decompress(); - - if(result.getNonZeros() < expected.getNonZeros()) - fail("Nonzero is to low guarantee at least equal or higher" + result.getNonZeros() + " vs " - + expected.getNonZeros()); + } if(_cs != null && _cs.lossy) TestUtils.compareMatricesPercentageDistance(expected, result, 0.25, 0.83, bufferedToString); @@ -1159,15 +1163,45 @@ else if(OverLapping.effectOnOutput(overlappingType)) else TestUtils.compareMatricesBitAvgDistance(expected, result, (long) (27000 * toleranceMultiplier), (long) (1024 * toleranceMultiplier), bufferedToString); + + if(result.getNonZeros() < expected.getNonZeros()) + fail("Nonzero is to low guarantee at least equal or higher " + result.getNonZeros() + " vs " + + expected.getNonZeros()); + + } + + protected void verifyContainsAllColumns(CompressedMatrixBlock mb) { + boolean[] cols = new boolean[mb.getNumColumns()]; + List groups = mb.getColGroups(); + + for(int i = 0; i < groups.size(); i++) { + AColGroup g = groups.get(i); + IColIndex idx = g.getColIndices(); + IIterate it = idx.iterator(); + while(it.hasNext()) { + cols[it.v()] = true; + it.next(); + } + } + + for(int i = 0; i < cols.length; i++) { + if(!cols[i]) + fail("Invalid constructed compression is missing column: " + i); + } + } protected void compareResultMatricesPercentDistance(MatrixBlock expected, MatrixBlock result, double avg, double max) { TestUtils.assertEqualColsAndRows(expected, result); - if(expected instanceof CompressedMatrixBlock) + if(expected instanceof CompressedMatrixBlock) { + verifyContainsAllColumns((CompressedMatrixBlock) expected); expected = ((CompressedMatrixBlock) expected).decompress(); - if(result instanceof CompressedMatrixBlock) + } + if(result instanceof CompressedMatrixBlock) { + verifyContainsAllColumns((CompressedMatrixBlock) result); result = ((CompressedMatrixBlock) result).decompress(); + } TestUtils.compareMatricesPercentageDistance(expected, result, avg, max, bufferedToString); diff --git a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java index 25c8550ae6a..3941ba7463b 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java @@ -1817,7 +1817,7 @@ public void rightMultMatrixDiagonalSparseWithCols() { public void rightMultWithAllCols(MatrixBlock right) { try { - final IColIndex cols = ColIndexFactory.create(right.getNumColumns()); + final IColIndex cols = ColIndexFactory.create(right.getNumColumns()); AColGroup b = base.rightMultByMatrix(right, cols); AColGroup o = other.rightMultByMatrix(right, cols); if(!(b == null && o == null)) @@ -1933,7 +1933,7 @@ private AColGroup getColGroup(MatrixBlock mbt, CompressionType ct) { protected static AColGroup getColGroup(MatrixBlock mbt, CompressionType ct, int nRow) { try { - final IColIndex cols = ColIndexFactory.create(mbt.getNumRows()); + final IColIndex cols = ColIndexFactory.create(mbt.getNumRows()); final List es = new ArrayList<>(); final EstimationFactors f = new EstimationFactors(nRow, nRow, mbt.getSparsity()); es.add(new CompressedSizeInfoColGroup(cols, f, 321452, ct)); @@ -2005,7 +2005,7 @@ public void tsmmSelfOther() { @Test(expected = DMLCompressionException.class) public void tsmmEmpty() { - tsmmColGroup(new ColGroupEmpty( ColIndexFactory.create(new int[] {1, 3, 10}))); + tsmmColGroup(new ColGroupEmpty(ColIndexFactory.create(new int[] {1, 3, 10}))); throw new DMLCompressionException("The output is verified correct just ignore not implemented"); } @@ -2172,12 +2172,12 @@ public void sliceRows(int rl, int ru) { return; assertTrue(a.getColIndices() == base.getColIndices()); assertTrue(b.getColIndices() == other.getColIndices()); - + int nRow = ru - rl; MatrixBlock ot = sparseMB(ru - rl, maxCol); MatrixBlock bt = sparseMB(ru - rl, maxCol); decompressToSparseBlock(a, b, ot, bt, 0, nRow); - + MatrixBlock otd = denseMB(ru - rl, maxCol); MatrixBlock btd = denseMB(ru - rl, maxCol); decompressToDenseBlock(otd, btd, a, b, 0, nRow); @@ -2210,9 +2210,19 @@ private void expectNotImplementedSlice(int rl, int ru) { @Test public void getScheme() { - // create scheme and check if it compress the same matrix input in same way. - checkScheme(base.getCompressionScheme(), base, nRow, maxCol); - checkScheme(other.getCompressionScheme(), other, nRow, maxCol); + try { + // create scheme and check if it compress the same matrix input in same way. + compare(base, other); + checkScheme(base.getCompressionScheme(), base, nRow, maxCol); + checkScheme(other.getCompressionScheme(), other, nRow, maxCol); + } + catch(NotImplementedException e) { + // allow it to be not implemented + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } } private static void checkScheme(ICLAScheme ia, AColGroup a, int nRow, int nCol) { @@ -2231,6 +2241,10 @@ private static void checkScheme(ICLAScheme ia, AColGroup a, int nRow, int nCol) } } + catch(NotImplementedException e) { + // allow it to be not implemented + + } catch(Exception e) { e.printStackTrace(); fail(e.getMessage()); diff --git a/src/test/java/org/apache/sysds/test/component/compress/indexes/CustomIndexTest.java b/src/test/java/org/apache/sysds/test/component/compress/indexes/CustomIndexTest.java index eb2d386a5bf..3286a3eed61 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/indexes/CustomIndexTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/indexes/CustomIndexTest.java @@ -21,7 +21,10 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import java.util.ArrayList; import java.util.Arrays; @@ -36,6 +39,7 @@ import org.apache.sysds.runtime.compress.colgroup.indexes.RangeIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.SingleIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.TwoIndex; +import org.apache.sysds.runtime.compress.colgroup.indexes.TwoRangesIndex; import org.apache.sysds.runtime.compress.utils.Util; import org.junit.Test; import org.mockito.Mockito; @@ -792,16 +796,235 @@ public void isSortedArray1() { @Test public void isSortedArray2() { - assertFalse(ColIndexFactory.createI(0, 1, 5, 3, 9).isSorted()); + assertFalse(new ArrayIndex(new int[] {0, 1, 5, 3, 9}).isSorted()); } @Test public void isSortedArray3() { - assertFalse(ColIndexFactory.createI(0, 1, 5, 9, -13).isSorted()); + assertFalse(new ArrayIndex(new int[] {0, 1, 5, 9, -13}).isSorted()); } @Test public void isSortedArray4() { - assertFalse(ColIndexFactory.createI(0, 1, 0, 1, 0).isSorted()); + assertFalse(new ArrayIndex(new int[] {0, 1, 0, 1, 0}).isSorted()); + } + + @Test + public void combine_1() { + IColIndex a = ColIndexFactory.createI(0, 1, 2, 3); + IColIndex b = ColIndexFactory.createI(4, 5, 6, 7); + IColIndex e = ColIndexFactory.createI(0, 1, 2, 3, 4, 5, 6, 7); + assertEquals(e, a.combine(b)); + } + + @Test + public void sortArray() { + IColIndex a = new ArrayIndex(new int[] {6, 7, 3, 2, 8}); + IColIndex b = a.sort(); + IColIndex e = ColIndexFactory.createI(2, 3, 6, 7, 8); + assertFalse(a.isSorted()); + assertTrue(b.isSorted()); + assertNotEquals(e, a); + assertEquals(e, b); + } + + @Test + public void sortArray2() { + IColIndex a = new ArrayIndex(new int[] {6, 7, 3, 2, 8}); + IColIndex b = a.sort(); + IColIndex e = ColIndexFactory.createI(2, 3, 6, 7, 8); + assertFalse(a.isSorted()); + assertTrue(b.isSorted()); + assertNotEquals(e, a); + assertEquals(e, b); + } + + @Test + public void getReorderingIndex() { + IColIndex a = new ArrayIndex(new int[] {6, 4, 3, 2, 1}); + int[] b = a.getReorderingIndex(); + int[] e = new int[] {4, 3, 2, 1, 0}; + assertTrue(Arrays.equals(e, b)); + } + + @Test + public void combineToRangeFromArray() { + IColIndex a = ColIndexFactory.createI(0, 2, 4, 6, 8); + IColIndex b = ColIndexFactory.createI(1, 3, 5, 7, 9); + IColIndex e = ColIndexFactory.create(0, 10); + assertEquals(e, a.combine(b)); + } + + @Test + public void combineToRangeFromArray2() { + IColIndex a = ColIndexFactory.createI(0, 2, 4, 6, 8); + IColIndex b = ColIndexFactory.createI(1, 3, 5, 7); + IColIndex e = ColIndexFactory.create(0, 9); + assertEquals(e, a.combine(b)); + } + + @Test + public void combineToRangeFromArray3() { + IColIndex a = ColIndexFactory.createI(2, 4, 6, 8); + IColIndex b = ColIndexFactory.createI(1, 3, 5, 7); + IColIndex e = ColIndexFactory.create(1, 9); + assertEquals(e, a.combine(b)); + } + + @Test + public void combineToRangeFromArray4() { + IColIndex a = ColIndexFactory.createI(2, 4, 6, 8); + IColIndex b = ColIndexFactory.createI(1, 3, 5, 7, 9, 10, 11); + IColIndex e = ColIndexFactory.create(1, 12); + assertEquals(e, a.combine(b)); + } + + @Test + public void avgIndex() { + IColIndex a = ColIndexFactory.createI(2, 4, 6, 8); + assertEquals(5.0, a.avgOfIndex(), 0.01); + } + + @Test + public void avgIndex2() { + IColIndex a = ColIndexFactory.createI(2, 4, 6); + assertEquals(4.0, a.avgOfIndex(), 0.01); + } + + @Test + public void avgIndex3() { + IColIndex a = ColIndexFactory.createI(2, 6); + assertEquals(4.0, a.avgOfIndex(), 0.01); + } + + @Test + public void avgIndex4() { + IColIndex a = ColIndexFactory.createI(2); + assertEquals(2.0, a.avgOfIndex(), 0.01); + } + + @Test + public void avgIndex5() { + IColIndex a = ColIndexFactory.create(0, 10); + assertEquals(4.5, a.avgOfIndex(), 0.01); + } + + @Test + public void combineColGroups() { + AColGroup a = mock(AColGroup.class); + when(a.getColIndices()).thenReturn(ColIndexFactory.createI(1, 2, 5, 6)); + AColGroup b = mock(AColGroup.class); + when(b.getColIndices()).thenReturn(ColIndexFactory.createI(3, 4, 8)); + IColIndex e = ColIndexFactory.createI(1, 2, 3, 4, 5, 6, 8); + assertEquals(e, ColIndexFactory.combine(a, b)); + } + + @Test + public void combineArrayOfIndexes() { + List l = new ArrayList<>(); + l.add(ColIndexFactory.createI(1)); + l.add(ColIndexFactory.createI(3, 5)); + l.add(ColIndexFactory.createI(4, 7, 8, 9)); + l.add(ColIndexFactory.createI(10, 11, 12, 13, 14)); + + IColIndex e = ColIndexFactory.createI(1, 3, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14); + assertEquals(e, ColIndexFactory.combineIndexes(l)); + } + + @Test + public void containsAny() { + IColIndex a = ColIndexFactory.createI(27, 28, 29); + IColIndex b = ColIndexFactory.createI(61, 62, 63); + IColIndex c = a.combine(b); + assertTrue(c instanceof TwoRangesIndex); + + IColIndex d = ColIndexFactory.createI(12); + + assertFalse(c.containsAny(d)); + assertFalse(d.containsAny(c)); + } + + @Test + public void combineRanges() { + IColIndex a = ColIndexFactory.createI(1, 2, 3, 4); + IColIndex b = ColIndexFactory.createI(5, 6, 7, 8); + IColIndex e = ColIndexFactory.createI(1, 2, 3, 4, 5, 6, 7, 8); + assertEquals(e, a.combine(b)); + } + + @Test + public void combineRanges2() { + IColIndex b = ColIndexFactory.createI(1, 2, 3, 4); + IColIndex a = ColIndexFactory.createI(5, 6, 7, 8); + IColIndex e = ColIndexFactory.createI(1, 2, 3, 4, 5, 6, 7, 8); + assertEquals(e, a.combine(b)); + } + + @Test + public void combineRanges3() { + IColIndex b = ColIndexFactory.createI(1, 2, 3, 4); + IColIndex a = ColIndexFactory.createI(6, 7, 8, 9); + IColIndex e = ColIndexFactory.createI(1, 2, 3, 4, 6, 7, 8, 9); + assertEquals(e, a.combine(b)); + } + + @Test + public void combineRanges4() { + IColIndex a = ColIndexFactory.createI(1, 2, 3, 4); + IColIndex b = ColIndexFactory.createI(6, 7, 8, 9); + IColIndex e = ColIndexFactory.createI(1, 2, 3, 4, 6, 7, 8, 9); + assertEquals(e, a.combine(b)); + } + + @Test + public void containsTest() { + // to get coverage + IColIndex a = new TwoRangesIndex(new RangeIndex(1, 10), new RangeIndex(5, 10)); + assertTrue(a.contains(7)); + assertTrue(a.contains(2)); + assertTrue(a.contains(9)); + assertFalse(a.contains(-1)); + assertFalse(a.contains(11)); + assertFalse(a.contains(10)); + } + + @Test + public void containsTest2() { + // to get coverage + IColIndex a = new TwoRangesIndex(new RangeIndex(1, 4), new RangeIndex(11, 20)); + assertFalse(a.contains(7)); + assertTrue(a.contains(2)); + assertTrue(a.contains(11)); + assertFalse(a.contains(-1)); + assertFalse(a.contains(20)); + assertFalse(a.contains(10)); + } + + @Test + public void containsAnyArray1() { + IColIndex a = new TwoRangesIndex(new RangeIndex(1, 4), new RangeIndex(11, 20)); + IColIndex b = new RangeIndex(7, 15); + assertTrue(a.containsAny(b)); + } + + @Test + public void containsAnyArrayF1() { + IColIndex a = new TwoRangesIndex(new RangeIndex(1, 4), new RangeIndex(11, 20)); + IColIndex b = new RangeIndex(20, 25); + assertFalse(a.containsAny(b)); + } + + @Test + public void containsAnyArrayF2() { + IColIndex a = new TwoRangesIndex(new RangeIndex(1, 4), new RangeIndex(11, 20)); + IColIndex b = new RangeIndex(4, 11); + assertFalse(a.containsAny(b)); + } + + @Test + public void containsAnyArray2() { + IColIndex a = new TwoRangesIndex(new RangeIndex(1, 4), new RangeIndex(11, 20)); + IColIndex b = new RangeIndex(3, 11); + assertTrue(a.containsAny(b)); } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/indexes/IndexesTest.java b/src/test/java/org/apache/sysds/test/component/compress/indexes/IndexesTest.java index ed5e81c97d6..088c64798b0 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/indexes/IndexesTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/indexes/IndexesTest.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -34,14 +35,19 @@ import java.util.List; import java.util.Random; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.indexes.ArrayIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex.SliceResult; import org.apache.sysds.runtime.compress.colgroup.indexes.IIterate; +import org.apache.sysds.runtime.compress.colgroup.indexes.RangeIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.SingleIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.TwoIndex; +import org.apache.sysds.runtime.compress.colgroup.indexes.TwoRangesIndex; import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.utils.MemoryEstimates; import org.junit.Test; @@ -51,6 +57,7 @@ @RunWith(value = Parameterized.class) public class IndexesTest { + public static final Log LOG = LogFactory.getLog(IndexesTest.class.getName()); private final int[] expected; private final IColIndex actual; @@ -97,6 +104,26 @@ public static Collection data() { new int[] {4, 5, 6, 7, 8, 9}, // ColIndexFactory.create(4, 10)}); + tests.add(new Object[] {// + new int[] {4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, // + ColIndexFactory.create(4, 19)}); + tests.add(new Object[] {// + new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, // + ColIndexFactory.create(0, 19)}); + tests.add(new Object[] {// + new int[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, // + ColIndexFactory.create(1, 19)}); + + tests.add(new Object[] {// + new int[] {1, 2, 3, 4, 5, 6}, // + ColIndexFactory.create(1, 7)}); + tests.add(new Object[] {// + new int[] {2, 3, 4, 5, 6}, // + ColIndexFactory.create(2, 7)}); + tests.add(new Object[] {// + new int[] {3, 4, 5, 6}, // + ColIndexFactory.create(3, 7)}); + tests.add(createWithArray(1, 323)); tests.add(createWithArray(2, 1414)); tests.add(createWithArray(144, 32)); @@ -110,6 +137,10 @@ public static Collection data() { tests.add(createRangeWithArray(4, 132)); tests.add(createRangeWithArray(2, 132)); tests.add(createRangeWithArray(1, 132)); + tests.add(createTwoRange(1, 10, 20, 30)); + tests.add(createTwoRange(1, 10, 22, 30)); + tests.add(createTwoRange(9, 11, 22, 30)); + tests.add(createTwoRange(9, 11, 22, 60)); } catch(Exception e) { e.printStackTrace(); @@ -148,7 +179,7 @@ public void testSerialize() { compare(actual, n); } catch(IOException e) { - throw new RuntimeException("Error in io", e); + throw new RuntimeException("Error in io " + actual, e); } catch(Exception e) { e.printStackTrace(); @@ -226,7 +257,7 @@ public void slice_1() { @Test public void slice_2() { - if(expected[0] <= 1) { + if(expected[0] < 1) { SliceResult sr = actual.slice(1, expected[expected.length - 1] + 1); String errStr = actual.toString(); @@ -238,11 +269,84 @@ public void slice_2() { } } + @Test + public void slice_3() { + for(int e = 0; e < actual.size(); e++) { + + SliceResult sr = actual.slice(expected[e], expected[expected.length - 1] + 1); + String errStr = actual.toString(); + if(sr.ret != null) { + IColIndex a = sr.ret; + assertEquals(errStr, a.size(), actual.size() - e); + assertEquals(errStr, a.get(0), 0); + assertEquals(errStr, a.get(a.size() - 1), expected[expected.length - 1] - expected[e]); + } + } + } + + @Test + public void slice_4() { + SliceResult sr = actual.slice(-10, -1); + assertEquals(null, sr.ret); + assertEquals(0, sr.idEnd); + assertEquals(0, sr.idStart); + } + + @Test + public void slice_5_moreThanRange() { + SliceResult sr = actual.slice(-10, expected[expected.length - 1] + 10); + assertTrue(sr.toString() + " " + actual, sr.ret.contains(expected[0] + 10)); + assertEquals(0, sr.idStart); + } + + @Test + public void slice_5_SubRange() { + if(expected.length > 5) { + + SliceResult sr = actual.slice(4, expected[5] + 1); + + assertEquals(actual.toString(), expected[5], sr.ret.get(sr.ret.size() - 1) + 4); + } + } + @Test public void equals() { assertEquals(actual, ColIndexFactory.create(expected)); } + @Test + public void equalsSizeDiff_range() { + if(actual.size() == 10) + return; + + IColIndex a = new RangeIndex(0, 10); + assertNotEquals(actual, a); + } + + @Test + public void equalsSizeDiff_twoRanges() { + if(actual.size() == 10) + return; + + IColIndex a = new TwoRangesIndex(new RangeIndex(0, 5), new RangeIndex(6, 10)); + assertNotEquals(actual, a); + } + + @Test + public void equalsSizeDiff_twoRanges2() { + if(actual.size() == 10 + 3) + return; + RangeIndex a = new RangeIndex(1, 10); + RangeIndex b = new RangeIndex(22, 25); + TwoRangesIndex c = (TwoRangesIndex) a.combine(b); + assertNotEquals(actual, c); + } + + @Test + public void equalsItself() { + assertEquals(actual, actual); + } + @Test public void isContiguous() { boolean c = expected[expected.length - 1] - expected[0] + 1 == expected.length; @@ -300,7 +404,8 @@ public void combineTwoBellow() { @Test public void hashCodeEquals() { - assertEquals(actual.hashCode(), ColIndexFactory.create(expected).hashCode()); + if(!(actual instanceof TwoRangesIndex)) + assertEquals(actual.hashCode(), ColIndexFactory.create(expected).hashCode()); } @Test @@ -308,6 +413,176 @@ public void estimateInMemorySizeIsNotToBig() { assertTrue(MemoryEstimates.intArrayCost(expected.length) >= actual.estimateInMemorySize() - 16); } + @Test + public void containsInt1() { + assertTrue(actual.contains(expected[0])); + } + + @Test + public void containsInt2() { + assertTrue(actual.contains(expected[expected.length - 1])); + } + + @Test + public void containsIntAllElements() { + for(int i = 0; i < expected.length; i++) + assertTrue(actual.contains(expected[i])); + } + + @Test + public void containsIntNot1() { + assertFalse(actual.contains(expected[expected.length - 1] + 3)); + } + + @Test + public void containsIntNot2() { + assertFalse(actual.toString(), actual.contains(expected[0] - 1)); + } + + @Test + public void containsIntNotAllInbetween() { + int j = 0; + for(int i = expected[0]; i < expected[expected.length - 1]; i++) { + if(i == expected[j]) { + j++; + assertTrue(actual.toString(), actual.contains(i)); + } + else { + assertFalse(actual.toString(), actual.contains(i)); + } + } + } + + @Test + public void containsAnySingle() { + assertTrue(actual.containsAny(new SingleIndex(expected[expected.length - 1]))); + } + + @Test + public void containsAnySingleFalse1() { + assertFalse(actual.containsAny(new SingleIndex(expected[expected.length - 1] + 1))); + } + + @Test + public void containsAnySingleFalse2() { + assertFalse(actual.containsAny(new SingleIndex(expected[0] - 1))); + } + + @Test + public void containsAnyTwo() { + assertTrue(actual.containsAny(new TwoIndex(expected[expected.length - 1], expected[expected.length - 1] + 4))); + } + + @Test + public void containsAnyTwoFalse() { + assertFalse( + actual.containsAny(new TwoIndex(expected[expected.length - 1] + 1, expected[expected.length - 1] + 4))); + } + + @Test + public void iteratorsV() { + IIterate i = actual.iterator(); + while(i.hasNext()) { + int v = i.v(); + assertEquals(actual.toString(), v, i.next()); + } + } + + @Test + public void averageOfIndex() { + double a = actual.avgOfIndex(); + double s = 0.0; + for(int i = 0; i < expected.length; i++) + s += expected[i]; + + assertEquals(actual.toString(), s / expected.length, a, 0.0000001); + } + + @Test + public void isSorted() { + assertTrue(actual.isSorted()); + } + + @Test + public void sort() { + assertTrue(actual.isSorted()); + try { + + actual.sort();// should do nothing + } + catch(DMLCompressionException e) { + // okay + } + assertTrue(actual.isSorted()); + } + + @Test + public void getReorderingIndex() { + try { + + int[] ro = actual.getReorderingIndex(); + if(ro != null) { + for(int i = 0; i < ro.length - 1; i++) { + assertTrue(ro[i] < ro[i + 1]); + } + } + } + catch(DMLCompressionException e) { + // okay + } + } + + @Test + public void findIndexBefore() { + final String er = actual.toString(); + assertEquals(er, -1, actual.findIndex(expected[0] - 1)); + assertEquals(er, -1, actual.findIndex(expected[0] - 10)); + assertEquals(er, -1, actual.findIndex(expected[0] - 100)); + } + + @Test + public void findIndexAll() { + final String er = actual.toString(); + for(int i = 0; i < expected.length; i++) { + assertEquals(er, i, actual.findIndex(expected[i])); + } + } + + @Test + public void findIndexAllMinus1() { + final String er = actual.toString(); + for(int i = 1; i < expected.length; i++) { + if(expected[i - 1] == expected[i] - 1) { + assertEquals(er, i - 1, actual.findIndex(expected[i] - 1)); + } + else { + assertEquals(er, i * -1 - 1, actual.findIndex(expected[i] - 1)); + + } + } + } + + @Test + public void findIndexAfter() { + final int el = expected.length; + final String er = actual.toString(); + assertEquals(er, -el - 1, actual.findIndex(expected[el - 1] + 1)); + assertEquals(er, -el - 1, actual.findIndex(expected[el - 1] + 10)); + assertEquals(er, -el - 1, actual.findIndex(expected[el - 1] + 100)); + } + + @Test + public void testHash() { + // flawed test in the case hashes can collide, but it should be unlikely. + IColIndex a = ColIndexFactory.createI(1, 2, 3, 1342); + if(a.equals(actual)) { + assertEquals(a.hashCode(), actual.hashCode()); + } + else { + assertNotEquals(a.hashCode(), actual.hashCode()); + } + } + private void shift(int i) { compare(expected, actual.shift(i), i); } @@ -331,6 +606,7 @@ private static void compare(IColIndex expected, IColIndex actual) { } private static void compare(int[] expected, IIterate actual) { + // LOG.error(expected); for(int i = 0; i < expected.length; i++) { assertTrue(actual.hasNext()); assertEquals(i, actual.i()); @@ -379,4 +655,16 @@ private static Object[] createRangeWithArray(int size, int seed) { else throw new DMLRuntimeException("Invalid construction of range array"); } + + private static Object[] createTwoRange(int l1, int u1, int l2, int u2) { + RangeIndex a = new RangeIndex(l1, u1); + RangeIndex b = new RangeIndex(l2, u2); + TwoRangesIndex c = (TwoRangesIndex) a.combine(b); + int[] exp = new int[u1 - l1 + u2 - l2]; + for(int i = l1, j = 0; i < u1; i++, j++) + exp[j] = i; + for(int i = l2, j = u1 - l1; i < u2; i++, j++) + exp[j] = i; + return new Object[] {exp, c}; + } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/indexes/NegativeIndexTest.java b/src/test/java/org/apache/sysds/test/component/compress/indexes/NegativeIndexTest.java index 69f006ff799..785608598a4 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/indexes/NegativeIndexTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/indexes/NegativeIndexTest.java @@ -247,24 +247,24 @@ public void hashCode7() { @Test public void hashCode8() { assertTrue( - new RangeIndex(0, 10).hashCode() == new ArrayIndex(new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}).hashCode()); + new RangeIndex(0, 10).hashCode() != new ArrayIndex(new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}).hashCode()); } @Test public void hashCode9() { - assertTrue(new RangeIndex(0, 4).hashCode() == new ArrayIndex(new int[] {0, 1, 2, 3,}).hashCode()); + assertTrue(new RangeIndex(0, 4).hashCode() != new ArrayIndex(new int[] {0, 1, 2, 3,}).hashCode()); } @Test public void hashCode10() { assertTrue( - new RangeIndex(5555, 5560).hashCode() == new ArrayIndex(new int[] {5555, 5556, 5557, 5558, 5559}).hashCode()); + new RangeIndex(5555, 5560).hashCode() != new ArrayIndex(new int[] {5555, 5556, 5557, 5558, 5559}).hashCode()); } @Test public void hashCode11() { assertTrue(new RangeIndex(5000000, 5000005) - .hashCode() == new ArrayIndex(new int[] {5000000, 5000001, 5000002, 5000003, 5000004}).hashCode()); + .hashCode() != new ArrayIndex(new int[] {5000000, 5000001, 5000002, 5000003, 5000004}).hashCode()); } private static Object notRelated() { @@ -276,12 +276,12 @@ public void invalidCreate1() { ColIndexFactory.create(new int[0]); } - @Test(expected = NullPointerException.class) + @Test(expected = DMLRuntimeException.class) public void invalidCreate2() { ColIndexFactory.create(new IntArrayList()); } - @Test(expected = ArrayIndexOutOfBoundsException.class) + @Test(expected = DMLRuntimeException.class) public void invalidCreate3() { ColIndexFactory.create(new IntArrayList(0)); } @@ -320,4 +320,29 @@ public void invalidCreate8() { public void invalidCreate9() { ColIndexFactory.create(-10); } + + @Test(expected = DMLCompressionException.class) + public void invalidRange() { + new RangeIndex(10, 4); + } + + @Test(expected = DMLCompressionException.class) + public void invalidRange2() { + new RangeIndex(10, 10); + } + + @Test(expected = DMLCompressionException.class) + public void invalidRange3() { + ColIndexFactory.createI(0, -1, 2); + } + + @Test(expected = DMLCompressionException.class) + public void invalidRange4() { + ColIndexFactory.createI(0, 0, 2); + } + + @Test(expected = DMLCompressionException.class) + public void invalidRange5() { + ColIndexFactory.createI(0, 1, 1); + } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/lib/CombineGroupsTest.java b/src/test/java/org/apache/sysds/test/component/compress/lib/CombineGroupsTest.java index 120c5d2ed58..f118705c28d 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/lib/CombineGroupsTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/lib/CombineGroupsTest.java @@ -32,7 +32,7 @@ import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory; import org.apache.sysds.runtime.compress.colgroup.AColGroup; -import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; +import org.apache.sysds.runtime.compress.colgroup.indexes.ArrayIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.IIterate; import org.apache.sysds.runtime.compress.lib.CLALibCombineGroups; @@ -389,7 +389,7 @@ private static AColGroup moveCols(AColGroup g, int[] mix) { it.next(); } - g = g.copyAndSet(ColIndexFactory.create(newIndexes)); + g = g.copyAndSet(new ArrayIndex(newIndexes)); g = g.sortColumnIndexes(); return g; } diff --git a/src/test/java/org/apache/sysds/test/component/misc/ThreadPool.java b/src/test/java/org/apache/sysds/test/component/misc/ThreadPool.java index ca79e8800b3..af6de205d90 100644 --- a/src/test/java/org/apache/sysds/test/component/misc/ThreadPool.java +++ b/src/test/java/org/apache/sysds/test/component/misc/ThreadPool.java @@ -370,39 +370,28 @@ public void coverEdge() { public void invokeAndShutdownException() throws InterruptedException { ExecutorService p = mock(ExecutorService.class); ExecutorService c = new CommonThreadPool(p); - when(p.invokeAll(null)).thenThrow(new RuntimeException("Test")); - - CommonThreadPool.invokeAndShutdown(p, null); - + CommonThreadPool.invokeAndShutdown(c, null); } @Test public void invokeAndShutdown() throws InterruptedException { - ExecutorService p = mock(ExecutorService.class); ExecutorService c = new CommonThreadPool(p); - Collection> cc = (Collection>) null; when(p.invokeAll(cc)).thenReturn(new ArrayList>()); - CommonThreadPool.invokeAndShutdown(c, null); - } @Test @SuppressWarnings("all") - public void invokeAndShutdownV2() throws InterruptedException{ - + public void invokeAndShutdownV2() throws InterruptedException { ExecutorService p = mock(ExecutorService.class); ExecutorService c = new CommonThreadPool(p); - Collection> cc = (Collection>) null; List> f = new ArrayList>(); f.add(mock(Future.class)); - when(p.invokeAll(cc)).thenReturn(f ); - + when(p.invokeAll(cc)).thenReturn(f); CommonThreadPool.invokeAndShutdown(c, null); - } }