
[SYSTEMDS-3608] Cocode shortcut
This commit adds a shortcut to the queue-based cocode algorithm: if no
cocoding has happened in the last 5 tries in the queue, cocoding is
aborted. Previously, in cases where cocoding was borderline beneficial,
all pairs of columns would be tried single-threaded. The implementation
now avoids this all-pairs step when no cocoding is happening.
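
The shortcut roughly works as sketched below. This is a simplified, hypothetical Java illustration that uses plain doubles as stand-ins for group cost estimates; the real implementation in CoCodePriorityQue operates on CompressedSizeInfoColGroup and the SystemDS size/cost estimators.

import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;

public class CoCodeShortcutSketch {

    // Greedily combine the two cheapest groups while it pays off; abort the
    // pairing entirely after 5 consecutive tries without a successful combine.
    public static List<Double> combine(Queue<Double> que) {
        List<Double> ret = new ArrayList<>();
        Double l = que.poll();
        int lastCombine = 0; // consecutive tries without a combine
        while(que.peek() != null && lastCombine < 5) {
            double r = que.peek();
            double combined = tryCombine(l, r); // negative if combining does not pay off
            if(combined >= 0 && combined < l + r) {
                que.poll();        // consume r; l and r are now represented by the combined group
                que.add(combined);
                lastCombine = 0;
            }
            else {
                lastCombine++;     // count towards the abort threshold
                ret.add(l);        // keep l as its own group
            }
            l = que.poll();
        }
        // shortcut triggered (or queue exhausted): emit the remaining groups unpaired
        while(l != null) {
            ret.add(l);
            l = que.poll();
        }
        return ret;
    }

    // Stand-in for the size/cost estimator: pretend only small groups combine at a discount.
    private static double tryCombine(double a, double b) {
        return (a + b < 10) ? (a + b) * 0.8 : -1;
    }

    public static void main(String[] args) {
        Queue<Double> que = new PriorityQueue<>();
        for(double c : new double[] {1, 2, 3, 8, 9, 11, 12})
            que.add(c);
        System.out.println(combine(que));
    }
}

The key point is the lastCombine counter: every failed combine attempt increments it, a successful combine resets it, and once it reaches 5 the queue is simply drained without any further pairing.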

The cost optimization of cocoding has been tuned to add a minor delta to
the cost, based on the average of the column indexes in the groups. This
makes the queue-based optimizer more likely to cocode neighboring columns
when they have similar cost, which makes the cocoding algorithm faster in
cases with high correlation between adjacent columns.
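
A minimal sketch of the tuned queue ordering is shown below; the divisor 100000 and the average-of-index delta mirror the new getCost helper in CoCodePriorityQue (which adds x.getColumns().avgOfIndex() / 100000), but the class and method names here are made up for illustration.

public final class CoCodeCostSketch {

    // Queue ordering cost: the estimated compression cost plus a tiny delta
    // derived from the average column index of the group.
    public static double queueCost(double estimatedCost, double avgColumnIndex) {
        return estimatedCost + avgColumnIndex / 100000;
    }

    public static void main(String[] args) {
        // Two groups with identical estimated cost: the one whose columns sit earlier
        // in the matrix gets the lower queue cost and is therefore considered first.
        System.out.println(queueCost(42.0, 3.0));   // e.g. columns {2, 3, 4}
        System.out.println(queueCost(42.0, 750.0)); // e.g. columns {700, ..., 800}
    }
}

The delta is small enough that it only acts as a tie-breaker when two groups have (nearly) the same estimated cost.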

This commit also adds a TwoRangesIndex, which is useful when almost all
columns get cocoded and the greedy algorithm combines column groups across
thousands of groups. In this case, combining the column index arrays
dominated the runtime.
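
Conceptually, a two-ranges index stores two contiguous ranges of column indices instead of materializing every index when two large ranges are combined. The following is only an illustrative sketch of that idea, not the actual TwoRangesIndex API in SystemDS.

public final class TwoRangesIndexSketch {
    private final int s1, e1; // first range [s1, e1)
    private final int s2, e2; // second range [s2, e2)

    public TwoRangesIndexSketch(int s1, int e1, int s2, int e2) {
        this.s1 = s1; this.e1 = e1;
        this.s2 = s2; this.e2 = e2;
    }

    // number of columns covered, without allocating an array of indices
    public int size() {
        return (e1 - s1) + (e2 - s2);
    }

    // i-th column index, mapped into the first or second range
    public int get(int i) {
        final int n1 = e1 - s1;
        return i < n1 ? s1 + i : s2 + (i - n1);
    }

    public boolean contains(int col) {
        return (col >= s1 && col < e1) || (col >= s2 && col < e2);
    }
}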

Also added in this commit are compression schemes for SDC, refinements for
SDC, and a natural update-and-upgrade scheme progression from initially
empty columns to const to DDC. Next up is to add a transition to SDC for
the case where the distribution of values becomes dominated by specific
values.
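
The progression can be pictured as in the sketch below: a column starts with an "empty" scheme, upgrades to "const" on the first non-zero value, and upgrades to a dictionary-coded (DDC-like) scheme once a second distinct value appears. The names and structure are made up for illustration and do not reflect the SystemDS scheme API.

import java.util.LinkedHashMap;
import java.util.Map;

public final class SchemeProgressionSketch {
    enum Kind { EMPTY, CONST, DDC }

    private Kind kind = Kind.EMPTY;
    private double constant = 0;
    private final Map<Double, Integer> dict = new LinkedHashMap<>();

    public void update(double v) {
        switch(kind) {
            case EMPTY:
                if(v != 0) { // first non-zero value: upgrade to a constant scheme
                    kind = Kind.CONST;
                    constant = v;
                }
                break;
            case CONST:
                if(v != constant) { // a second distinct value forces a dictionary
                    kind = Kind.DDC;
                    dict.put(constant, 0);
                    dict.putIfAbsent(v, dict.size());
                }
                break;
            case DDC:
                dict.putIfAbsent(v, dict.size()); // grow the dictionary as needed
                break;
        }
    }

    public Kind kind() {
        return kind;
    }
}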

We also add a minor update to AggTernaryOp for compressed matrices, with a
shortcut that avoids a decompression in L2SVM.
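
For reference, an aggregate ternary operation of this kind computes the sum over the element-wise product of its three inputs, as in the naive dense sketch below; the shortcut evaluates this pattern for L2SVM without first decompressing the compressed operand (the compressed path itself is not shown here).

public final class AggTernarySketch {

    // Naive dense reference: sum of the element-wise product of three equally sized matrices.
    static double aggTernarySum(double[][] m1, double[][] m2, double[][] m3) {
        double sum = 0;
        for(int i = 0; i < m1.length; i++)
            for(int j = 0; j < m1[i].length; j++)
                sum += m1[i][j] * m2[i][j] * m3[i][j];
        return sum;
    }

    public static void main(String[] args) {
        double[][] a = {{1, 2}, {3, 4}};
        double[][] b = {{1, 1}, {1, 1}};
        double[][] c = {{2, 2}, {2, 2}};
        System.out.println(aggTernarySum(a, b, c)); // 20.0
    }
}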

Closes #1874
Baunsgaard committed Aug 8, 2023
1 parent 68c2c17 commit 8dbfc23
Showing 62 changed files with 2,457 additions and 300 deletions.
@@ -64,6 +64,7 @@
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseRow;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
@@ -77,7 +78,6 @@
import org.apache.sysds.runtime.matrix.data.RandomMatrixGenerator;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateTernaryOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
@@ -587,11 +587,21 @@ else if(isOverlapping()) {

@Override
public MatrixBlock reorgOperations(ReorgOperator op, MatrixValue ret, int startRow, int startColumn, int length) {
// Allow transpose to be compressed output. In general we need to have a transposed flag on
// the compressed matrix. https://issues.apache.org/jira/browse/SYSTEMDS-3025
printDecompressWarning(op.getClass().getSimpleName() + " -- " + op.fn.getClass().getSimpleName());
MatrixBlock tmp = decompress(op.getNumThreads());
return tmp.reorgOperations(op, ret, startRow, startColumn, length);
if(op.fn instanceof SwapIndex && this.getNumColumns() == 1) {
MatrixBlock tmp = decompress(op.getNumThreads());
long nz = tmp.setNonZeros(tmp.getNonZeros());
tmp = new MatrixBlock(tmp.getNumColumns(), tmp.getNumRows(), tmp.getDenseBlockValues());
tmp.setNonZeros(nz);
return tmp;
}
else {
// Allow transpose to be compressed output. In general we need to have a transposed flag on
// the compressed matrix. https://issues.apache.org/jira/browse/SYSTEMDS-3025
String message = op.getClass().getSimpleName() + " -- " + op.fn.getClass().getSimpleName();
MatrixBlock tmp = getUncompressed(message, op.getNumThreads());
return tmp.reorgOperations(op, ret, startRow, startColumn, length);
}

}

public boolean isOverlapping() {
@@ -788,24 +798,6 @@ public MatrixBlock sortOperations(MatrixValue weights, MatrixBlock result) {
return getUncompressed("sortOperations").sortOperations(right, result);
}

@Override
public MatrixBlock aggregateTernaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock m3, MatrixBlock ret,
AggregateTernaryOperator op, boolean inCP) {
boolean m1C = m1 instanceof CompressedMatrixBlock;
boolean m2C = m2 instanceof CompressedMatrixBlock;
boolean m3C = m3 instanceof CompressedMatrixBlock;
printDecompressWarning("aggregateTernaryOperations " + op.aggOp.getClass().getSimpleName() + " "
+ op.indexFn.getClass().getSimpleName() + " " + op.aggOp.increOp.fn.getClass().getSimpleName() + " "
+ op.binaryFn.getClass().getSimpleName() + " m1,m2,m3 " + m1C + " " + m2C + " " + m3C);
MatrixBlock left = getUncompressed(m1);
MatrixBlock right1 = getUncompressed(m2);
MatrixBlock right2 = getUncompressed(m3);
ret = left.aggregateTernaryOperations(left, right1, right2, ret, op, inCP);
if(ret.getNumRows() == 0 || ret.getNumColumns() == 0)
throw new DMLCompressionException("Invalid output");
return ret;
}

@Override
public MatrixBlock uaggouterchainOperations(MatrixBlock mbLeft, MatrixBlock mbRight, MatrixBlock mbOut,
BinaryOperator bOp, AggregateUnaryOperator uaggOp) {
@@ -265,8 +265,10 @@ public static CompressedMatrixBlock createConstant(int numRows, int numCols, dou
}

private Pair<MatrixBlock, CompressionStatistics> compressMatrix() {
if(mb.getNonZeros() < 0)
throw new DMLCompressionException("Invalid to compress matrices with unknown nonZeros");
if(mb.getNonZeros() < 0) {
LOG.warn("Recomputing non-zeros since it is unknown in compression");
mb.recomputeNonZeros();
}
else if(mb instanceof CompressedMatrixBlock && ((CompressedMatrixBlock) mb).isOverlapping()) {
LOG.warn("Unsupported recompression of overlapping compression");
return new ImmutablePair<>(mb, null);
@@ -48,8 +48,7 @@ public class CoCodePriorityQue extends AColumnCoCoder {

private static final int COL_COMBINE_THRESHOLD = 1024;

protected CoCodePriorityQue(AComEst sizeEstimator, ACostEstimate costEstimator,
CompressionSettings cs) {
protected CoCodePriorityQue(AComEst sizeEstimator, ACostEstimate costEstimator, CompressionSettings cs) {
super(sizeEstimator, costEstimator, cs);
}

@@ -59,8 +58,8 @@ protected CompressedSizeInfo coCodeColumns(CompressedSizeInfo colInfos, int k) {
return colInfos;
}

protected static List<CompressedSizeInfoColGroup> join(List<CompressedSizeInfoColGroup> groups,
AComEst sEst, ACostEstimate cEst, int minNumGroups, int k) {
protected static List<CompressedSizeInfoColGroup> join(List<CompressedSizeInfoColGroup> groups, AComEst sEst,
ACostEstimate cEst, int minNumGroups, int k) {

if(groups.size() > COL_COMBINE_THRESHOLD && k > 1)
return combineMultiThreaded(groups, sEst, cEst, minNumGroups, k);
@@ -111,37 +110,53 @@ private static List<CompressedSizeInfoColGroup> combineBlock(List<CompressedSize
return combineBlock(que, sEst, cEst, minNumGroups);
}

private static List<CompressedSizeInfoColGroup> combineBlock(Queue<CompressedSizeInfoColGroup> que,
AComEst sEst, ACostEstimate cEst, int minNumGroups) {
private static List<CompressedSizeInfoColGroup> combineBlock(Queue<CompressedSizeInfoColGroup> que, AComEst sEst,
ACostEstimate cEst, int minNumGroups) {

List<CompressedSizeInfoColGroup> ret = new ArrayList<>();
CompressedSizeInfoColGroup l = null;
l = que.poll();
int groupNr = ret.size() + que.size();
while(que.peek() != null && groupNr >= minNumGroups) {
int lastCombine = 0; // if we have not combined in the last 5 tries abort cocoding.

while(que.peek() != null && groupNr >= minNumGroups && lastCombine < 5) {
CompressedSizeInfoColGroup r = que.peek();
CompressedSizeInfoColGroup g = sEst.combine(l, r);

if(g != null) {
double costOfJoin = cEst.getCost(g);
double costIndividual = cEst.getCost(l) + cEst.getCost(r);

if(costOfJoin < costIndividual) {
que.poll();
int numColumns = g.getColumns().size();
if(numColumns > 128)
if(numColumns > 128){
lastCombine++;
ret.add(g);
else
}
else{
lastCombine = 0;
que.add(g);
}
}
else
else{
lastCombine++;
ret.add(l);
}
}
else
else{
lastCombine++;
ret.add(l);
}

l = que.poll();
groupNr = ret.size() + que.size();
}
while(que.peek() != null){
// empty que
ret.add(l);
l = que.poll();
}

if(l != null)
ret.add(l);
@@ -153,11 +168,15 @@ private static List<CompressedSizeInfoColGroup> combineBlock(Queue<CompressedSiz
}

private static Queue<CompressedSizeInfoColGroup> getQue(int size, ACostEstimate cEst) {
Comparator<CompressedSizeInfoColGroup> comp = Comparator.comparing(x -> cEst.getCost(x));
Comparator<CompressedSizeInfoColGroup> comp = Comparator.comparing(x -> getCost(x, cEst));
Queue<CompressedSizeInfoColGroup> que = new PriorityQueue<>(size, comp);
return que;
}

private static double getCost(CompressedSizeInfoColGroup x, ACostEstimate cEst) {
return cEst.getCost(x) + x.getColumns().avgOfIndex() / 100000;
}

protected static class PQTask implements Callable<List<CompressedSizeInfoColGroup>> {

private final List<CompressedSizeInfoColGroup> _groups;
@@ -167,8 +186,8 @@ protected static class PQTask implements Callable<List<CompressedSizeInfoColGrou
private final ACostEstimate _cEst;
private final int _minNumGroups;

protected PQTask(List<CompressedSizeInfoColGroup> groups, int start, int end, AComEst sEst,
ACostEstimate cEst, int minNumGroups) {
protected PQTask(List<CompressedSizeInfoColGroup> groups, int start, int end, AComEst sEst, ACostEstimate cEst,
int minNumGroups) {
_groups = groups;
_start = start;
_end = end;
@@ -22,16 +22,19 @@
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.cost.ACostEstimate;
import org.apache.sysds.runtime.compress.estim.AComEst;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.compress.utils.IntArrayList;

public interface CoCoderFactory {
public static final Log LOG = LogFactory.getLog(AColumnCoCoder.class.getName());

/**
* The Valid coCoding techniques
@@ -53,51 +56,69 @@ public enum PartitionerType {
* @param cs The compression settings used in the compression.
* @return The estimated (hopefully) best groups of ColGroups.
*/
public static CompressedSizeInfo findCoCodesByPartitioning(AComEst est, CompressedSizeInfo colInfos,
int k, ACostEstimate costEstimator, CompressionSettings cs) {
public static CompressedSizeInfo findCoCodesByPartitioning(AComEst est, CompressedSizeInfo colInfos, int k,
ACostEstimate costEstimator, CompressionSettings cs) {

// Use column group partitioner to create partitions of columns
AColumnCoCoder co = createColumnGroupPartitioner(cs.columnPartitioner, est, costEstimator, cs);

// Find out if any of the groups are empty.
boolean containsEmpty = false;
for(CompressedSizeInfoColGroup g : colInfos.compressionInfo) {
if(g.isEmpty()) {
containsEmpty = true;
break;
}
}
final boolean containsEmptyOrConst = containsEmptyOrConst(colInfos);

// if there are no empty columns then try cocode algorithms for all columns
if(!containsEmpty)
// if there are no empty or const columns then try cocode algorithms for all columns
if(!containsEmptyOrConst)
return co.coCodeColumns(colInfos, k);
else {
// filtered empty groups
final List<IColIndex> emptyCols = new ArrayList<>();
// filtered const groups
final List<IColIndex> constCols = new ArrayList<>();
// filtered groups -- in the end starting with all groups
final List<CompressedSizeInfoColGroup> groups = new ArrayList<>();

final int nRow = colInfos.compressionInfo.get(0).getNumRows();

// filter groups
for(int i = 0; i < colInfos.compressionInfo.size(); i++) {
CompressedSizeInfoColGroup g = colInfos.compressionInfo.get(i);
if(g.isEmpty())
emptyCols.add(g.getColumns());
else if(g.isConst())
constCols.add(g.getColumns());
else
groups.add(g);
}

// extract all empty columns
IntArrayList emptyCols = new IntArrayList();
List<CompressedSizeInfoColGroup> notEmpty = new ArrayList<>();
// overwrite groups.
colInfos.compressionInfo = groups;

// cocode remaining groups
if(!groups.isEmpty()) {
colInfos = co.coCodeColumns(colInfos, k);
}

for(CompressedSizeInfoColGroup g : colInfos.compressionInfo) {
if(g.isEmpty())
emptyCols.appendValue(g.getColumns().get(0));
else
notEmpty.add(g);
}
// add empty
if(emptyCols.size() > 0) {
final IColIndex idx = ColIndexFactory.combineIndexes(emptyCols);
colInfos.compressionInfo.add(new CompressedSizeInfoColGroup(idx, nRow, CompressionType.EMPTY));
}

final int nRow = colInfos.compressionInfo.get(0).getNumRows();
// add const
if(constCols.size() > 0) {
final IColIndex idx = ColIndexFactory.combineIndexes(constCols);
colInfos.compressionInfo.add(new CompressedSizeInfoColGroup(idx, nRow, CompressionType.CONST));
}

final IColIndex idx = ColIndexFactory.create(emptyCols);
if(notEmpty.isEmpty()) { // if all empty (unlikely but could happen)
CompressedSizeInfoColGroup empty = new CompressedSizeInfoColGroup(idx, nRow);
return new CompressedSizeInfo(empty);
}
return colInfos;

// cocode all not empty columns
colInfos.compressionInfo = notEmpty;
colInfos = co.coCodeColumns(colInfos, k);
}
}

// add empty columns back as single columns
colInfos.compressionInfo.add(new CompressedSizeInfoColGroup(idx, nRow));
return colInfos;
private static boolean containsEmptyOrConst(CompressedSizeInfo colInfos) {
for(CompressedSizeInfoColGroup g : colInfos.compressionInfo)
if(g.isEmpty() || g.isConst())
return true;
return false;
}

private static AColumnCoCoder createColumnGroupPartitioner(PartitionerType type, AComEst est,
@@ -45,10 +45,6 @@ public boolean contains(ColIndexes a, ColIndexes b) {

if(a == null || b == null)
return false;
int id = _indexes.findIndex(a._indexes.get(0));
if(id >= 0)
return true;
id = _indexes.findIndex(b._indexes.get(0));
return id >= 0;
return _indexes.contains(a._indexes.get(0)) || _indexes.contains(b._indexes.get(0));
}
}
@@ -22,6 +22,8 @@
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.colgroup.offset.AOffset;
import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme;
import org.apache.sysds.runtime.compress.colgroup.scheme.SDCScheme;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.compress.estim.EstimationFactors;

@@ -62,4 +64,9 @@ public final CompressedSizeInfoColGroup getCompressionInfo(int nRow) {
EstimationFactors ef = new EstimationFactors(getNumValues(), _numRows, getNumberOffsets(), _dict.getSparsity());
return new CompressedSizeInfoColGroup(_colIndexes, ef, nRow, getCompType(),getEncoding());
}

@Override
public ICLAScheme getCompressionScheme() {
return SDCScheme.create(this);
}
}
@@ -24,6 +24,8 @@
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.colgroup.offset.AIterator;
import org.apache.sysds.runtime.compress.colgroup.offset.AOffset;
import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme;
import org.apache.sysds.runtime.compress.colgroup.scheme.SDCScheme;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.compress.estim.EstimationFactors;
import org.apache.sysds.runtime.data.DenseBlock;
@@ -230,4 +232,9 @@ public final CompressedSizeInfoColGroup getCompressionInfo(int nRow) {
EstimationFactors ef = new EstimationFactors(getNumValues(), _numRows, getNumberOffsets(), _dict.getSparsity());
return new CompressedSizeInfoColGroup(_colIndexes, ef, nRow, getCompType(),getEncoding());
}

@Override
public ICLAScheme getCompressionScheme() {
return SDCScheme.create(this);
}
}
@@ -459,12 +459,12 @@ public AColGroup appendNInternal(AColGroup[] g) {

@Override
public ICLAScheme getCompressionScheme() {
return null;
throw new NotImplementedException();
}

@Override
public AColGroup recompress() {
return this;
throw new NotImplementedException();
}

@Override
@@ -352,7 +352,8 @@ public AColGroup recompress() {
@Override
public CompressedSizeInfoColGroup getCompressionInfo(int nRow) {
EstimationFactors ef = new EstimationFactors(getNumValues(), 1, 0, 0.0);
return new CompressedSizeInfoColGroup(_colIndexes, ef, estimateInMemorySize(), CompressionType.CONST, getEncoding());
return new CompressedSizeInfoColGroup(_colIndexes, ef, estimateInMemorySize(), CompressionType.EMPTY,
getEncoding());
}

@Override
@@ -671,7 +671,7 @@ public AColGroup appendNInternal(AColGroup[] g) {

@Override
public ICLAScheme getCompressionScheme() {
return null;
throw new NotImplementedException();
}

@Override

