Skip to content

Commit

Permalink
IDictionary
Browse files Browse the repository at this point in the history
  • Loading branch information
Baunsgaard committed Aug 11, 2023
1 parent 11c25b8 commit 1a51720
Show file tree
Hide file tree
Showing 38 changed files with 1,814 additions and 1,132 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ protected CompressedSizeInfo coCodeColumns(CompressedSizeInfo colInfos, int k) {
final int startSize = colInfos.getInfo().size();
if(startSize == 1)
return colInfos; // nothing to join when there only is one column
else if(startSize <= 5) {// Greedy all compare all if small number of columns
else if(startSize <= 16) {// Greedy all compare all if small number of columns
LOG.debug("Hybrid chose to do greedy cocode because of few columns");
CoCodeGreedy gd = new CoCodeGreedy(_sest, _cest, _cs);
return colInfos.setInfo(gd.combine(colInfos.getInfo(), k));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.DMLScriptException;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.data.SparseBlock;
Expand Down Expand Up @@ -205,7 +205,7 @@ public final void tsmm(MatrixBlock ret, int nRows) {

protected abstract void tsmm(double[] result, int numColumns, int nRows);

protected static void tsmm(double[] result, int numColumns, int[] counts, ADictionary dict, IColIndex colIndexes) {
protected static void tsmm(double[] result, int numColumns, int[] counts, IDictionary dict, IColIndex colIndexes) {
dict = dict.getMBDict(colIndexes.size());
final MatrixBlock mb = ((MatrixBlockDictionary) dict).getMatrixBlock();
if(mb.isInSparseFormat())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import java.io.DataOutput;
import java.io.IOException;

import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.utils.MemoryEstimates;

Expand All @@ -46,7 +46,7 @@ public abstract class AColGroupOffset extends APreAgg {
/** If the column group contains unassigned rows. */
protected final boolean _zeros;

protected AColGroupOffset(IColIndex colIndices, int numRows, boolean zeros, ADictionary dict, int[] ptr, char[] data, int[] cachedCounts) {
protected AColGroupOffset(IColIndex colIndices, int numRows, boolean zeros, IDictionary dict, int[] ptr, char[] data, int[] cachedCounts) {
super(colIndices, dict, cachedCounts);
_numRows = numRows;
_zeros = zeros;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
import java.util.HashSet;
import java.util.Set;

import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
Expand All @@ -37,23 +38,23 @@
public abstract class ADictBasedColGroup extends AColGroupCompressed implements IContainADictionary {
private static final long serialVersionUID = -3737025296618703668L;
/** Distinct value tuples associated with individual bitmaps. */
protected final ADictionary _dict;
protected final IDictionary _dict;

/**
* A Abstract class for column groups that contain ADictionary for values.
* A Abstract class for column groups that contain IDictionary for values.
*
* @param colIndices The Column indexes
* @param dict The dictionary to contain the distinct tuples
*/
protected ADictBasedColGroup(IColIndex colIndices, ADictionary dict) {
protected ADictBasedColGroup(IColIndex colIndices, IDictionary dict) {
super(colIndices);
_dict = dict;
if(dict == null)
throw new NullPointerException("null dict is invalid");

}

public ADictionary getDictionary() {
public IDictionary getDictionary() {
return _dict;
}

Expand Down Expand Up @@ -197,14 +198,14 @@ public final AColGroup rightMultByMatrix(MatrixBlock right, IColIndex allCols) {
return null;

final int nVals = getNumValues();
final ADictionary preAgg = (right.isInSparseFormat()) ? // Chose Sparse or Dense
final IDictionary preAgg = (right.isInSparseFormat()) ? // Chose Sparse or Dense
rightMMPreAggSparse(nVals, right.getSparseBlock(), agCols, 0, nCol) : // sparse
_dict.preaggValuesFromDense(nVals, _colIndexes, agCols, right.getDenseBlockValues(), nCol); // dense
return allocateRightMultiplication(right, agCols, preAgg);
}

protected abstract AColGroup allocateRightMultiplication(MatrixBlock right, IColIndex colIndexes,
ADictionary preAgg);
IDictionary preAgg);

/**
* Find the minimum number of columns that are effected by the right multiplication
Expand Down Expand Up @@ -269,7 +270,7 @@ protected IColIndex rightMMGetColsSparse(SparseBlock b, int retCols, IColIndex a
return ColIndexFactory.create(aggregateColumns);
}

private ADictionary rightMMPreAggSparse(int numVals, SparseBlock b, IColIndex aggregateColumns, int cl, int cu) {
private IDictionary rightMMPreAggSparse(int numVals, SparseBlock b, IColIndex aggregateColumns, int cl, int cu) {
final double[] ret = new double[numVals * aggregateColumns.size()];
for(int h = 0; h < _colIndexes.size(); h++) {
final int colIdx = _colIndexes.get(h);
Expand Down Expand Up @@ -300,10 +301,10 @@ public final AColGroup copyAndSet(IColIndex colIndexes) {
return copyAndSet(colIndexes, _dict);
}

protected final AColGroup copyAndSet(ADictionary newDictionary) {
protected final AColGroup copyAndSet(IDictionary newDictionary) {
return copyAndSet(_colIndexes, newDictionary);
}

protected abstract AColGroup copyAndSet(IColIndex colIndexes, ADictionary newDictionary);
protected abstract AColGroup copyAndSet(IColIndex colIndexes, IDictionary newDictionary);

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
package org.apache.sysds.runtime.compress.colgroup;

import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.lib.CLALibLeftMultBy;
import org.apache.sysds.runtime.compress.lib.CLALibTSMM;
Expand All @@ -37,14 +37,14 @@ public abstract class AMorphingMMColGroup extends AColGroupValue {
private static final long serialVersionUID = -4265713396790607199L;

/**
* A Abstract class for column groups that contain ADictionary for values.
* A Abstract class for column groups that contain IDictionary for values.
*
* @param colIndices The Column indexes
* @param dict The dictionary to contain the distinct tuples
* @param cachedCounts The cached counts of the distinct tuples (can be null since it should be possible to
* reconstruct the counts on demand)
*/
protected AMorphingMMColGroup(IColIndex colIndices, ADictionary dict, int[] cachedCounts) {
protected AMorphingMMColGroup(IColIndex colIndices, IDictionary dict, int[] cachedCounts) {
super(colIndices, dict, cachedCounts);
}

Expand Down Expand Up @@ -161,7 +161,7 @@ protected IColIndex rightMMGetColsSparse(SparseBlock b, int nCols, IColIndex all
}

@Override
protected AColGroup allocateRightMultiplication(MatrixBlock right, IColIndex colIndexes, ADictionary preAgg) {
protected AColGroup allocateRightMultiplication(MatrixBlock right, IColIndex colIndexes, IDictionary preAgg) {
LOG.warn("right mm should not be called directly on a morphing column group");
final double[] common = getCommon();
final int rc = right.getNumColumns();
Expand Down Expand Up @@ -195,7 +195,7 @@ protected AColGroup allocateRightMultiplication(MatrixBlock right, IColIndex col
}

protected abstract AColGroup allocateRightMultiplicationCommon(double[] common, IColIndex colIndexes,
ADictionary preAgg);
IDictionary preAgg);

/**
* extract common value from group and return non morphing group
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

package org.apache.sysds.runtime.compress.colgroup;

import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.colgroup.offset.AOffset;
import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme;
Expand All @@ -42,7 +42,7 @@ public abstract class ASDC extends AMorphingMMColGroup implements AOffsetsGroup
/** The number of rows in this column group */
protected final int _numRows;

protected ASDC(IColIndex colIndices, int numRows, ADictionary dict, AOffset offsets, int[] cachedCounts) {
protected ASDC(IColIndex colIndices, int numRows, IDictionary dict, AOffset offsets, int[] cachedCounts) {
super(colIndices, dict, cachedCounts);
_indexes = offsets;
_numRows = numRows;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

package org.apache.sysds.runtime.compress.colgroup;

import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.colgroup.offset.AIterator;
Expand All @@ -40,7 +40,7 @@ public abstract class ASDCZero extends APreAgg implements AOffsetsGroup, IContai
/** The number of rows in this column group */
protected final int _numRows;

protected ASDCZero(IColIndex colIndices, int numRows, ADictionary dict, AOffset offsets, int[] cachedCounts) {
protected ASDCZero(IColIndex colIndices, int numRows, IDictionary dict, AOffset offsets, int[] cachedCounts) {
super(colIndices, dict, cachedCounts);
_indexes = offsets;
_numRows = numRows;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
import java.io.IOException;

import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
Expand Down Expand Up @@ -58,7 +58,7 @@ public class ColGroupConst extends ADictBasedColGroup implements IContainDefault
* @param colIndices The Colum indexes for the column group.
* @param dict The dictionary containing one tuple for the entire compression.
*/
private ColGroupConst(IColIndex colIndices, ADictionary dict) {
private ColGroupConst(IColIndex colIndices, IDictionary dict) {
super(colIndices, dict);
}

Expand All @@ -70,7 +70,7 @@ private ColGroupConst(IColIndex colIndices, ADictionary dict) {
* @param dict The dictionary to use
* @return A Colgroup either const or empty.
*/
public static AColGroup create(IColIndex colIndices, ADictionary dict) {
public static AColGroup create(IColIndex colIndices, IDictionary dict) {
if(dict == null)
return new ColGroupEmpty(colIndices);
else if(dict.getNumberOfValues(colIndices.size()) > 1) {
Expand Down Expand Up @@ -147,7 +147,7 @@ public static AColGroup create(IColIndex cols, double[] values) {
* @param dict The dictionary to contain int the Constant group.
* @return A Constant column group.
*/
public static AColGroup create(int numCols, ADictionary dict) {
public static AColGroup create(int numCols, IDictionary dict) {
if(dict instanceof MatrixBlockDictionary) {
MatrixBlock mbd = ((MatrixBlockDictionary) dict).getMatrixBlock();
if(mbd.getNumColumns() != numCols && mbd.getNumRows() != 1) {
Expand Down Expand Up @@ -444,14 +444,14 @@ protected AColGroup sliceSingleColumn(int idx) {
if(v == 0)
return new ColGroupEmpty(colIndexes);
else {
ADictionary retD = Dictionary.create(new double[] {_dict.getValue(idx)});
IDictionary retD = Dictionary.create(new double[] {_dict.getValue(idx)});
return create(colIndexes, retD);
}
}

@Override
protected AColGroup sliceMultiColumns(int idStart, int idEnd, IColIndex outputCols) {
ADictionary retD = _dict.sliceOutColumnRange(idStart, idEnd, _colIndexes.size());
IDictionary retD = _dict.sliceOutColumnRange(idStart, idEnd, _colIndexes.size());
return create(outputCols, retD);
}

Expand All @@ -467,7 +467,7 @@ public long getNumberNonZeros(int nRows) {

@Override
public AColGroup replace(double pattern, double replace) {
ADictionary replaced = _dict.replace(pattern, replace, _colIndexes.size());
IDictionary replaced = _dict.replace(pattern, replace, _colIndexes.size());
return create(_colIndexes, replaced);
}

Expand Down Expand Up @@ -517,7 +517,7 @@ public CM_COV_Object centralMoment(CMOperator op, int nRows) {

@Override
public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) {
ADictionary d = _dict.rexpandCols(max, ignore, cast, _colIndexes.size());
IDictionary d = _dict.rexpandCols(max, ignore, cast, _colIndexes.size());
if(d == null)
return ColGroupEmpty.create(max);
else
Expand All @@ -534,12 +534,12 @@ protected AColGroup copyAndSet(IColIndex colIndexes, double[] newDictionary) {
return create(colIndexes, Dictionary.create(newDictionary));
}

protected AColGroup copyAndSet(IColIndex colIndexes, ADictionary newDictionary) {
protected AColGroup copyAndSet(IColIndex colIndexes, IDictionary newDictionary) {
return create(colIndexes, newDictionary);
}

@Override
protected AColGroup allocateRightMultiplication(MatrixBlock right, IColIndex colIndexes, ADictionary preAgg) {
protected AColGroup allocateRightMultiplication(MatrixBlock right, IColIndex colIndexes, IDictionary preAgg) {
if(colIndexes != null && preAgg != null)
return create(colIndexes, preAgg);
else
Expand All @@ -548,7 +548,7 @@ protected AColGroup allocateRightMultiplication(MatrixBlock right, IColIndex col

public static ColGroupConst read(DataInput in) throws IOException {
IColIndex cols = ColIndexFactory.read(in);
ADictionary dict = DictionaryFactory.read(in);
IDictionary dict = DictionaryFactory.read(in);
return new ColGroupConst(cols, dict);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
Expand Down Expand Up @@ -62,7 +62,7 @@ public class ColGroupDDC extends APreAgg implements IMapToDataGroup {

protected final AMapToData _data;

private ColGroupDDC(IColIndex colIndexes, ADictionary dict, AMapToData data, int[] cachedCounts) {
private ColGroupDDC(IColIndex colIndexes, IDictionary dict, AMapToData data, int[] cachedCounts) {
super(colIndexes, dict, cachedCounts);
_data = data;

Expand All @@ -77,7 +77,7 @@ private ColGroupDDC(IColIndex colIndexes, ADictionary dict, AMapToData data, int

}

public static AColGroup create(IColIndex colIndexes, ADictionary dict, AMapToData data, int[] cachedCounts) {
public static AColGroup create(IColIndex colIndexes, IDictionary dict, AMapToData data, int[] cachedCounts) {
if(data.getUnique() == 1)
return ColGroupConst.create(colIndexes, dict);
else if(dict == null)
Expand Down Expand Up @@ -431,7 +431,7 @@ public AColGroup unaryOperation(UnaryOperator op) {

@Override
public AColGroup binaryRowOpLeft(BinaryOperator op, double[] v, boolean isRowSafe) {
ADictionary ret = _dict.binOpLeft(op, v, _colIndexes);
IDictionary ret = _dict.binOpLeft(op, v, _colIndexes);
return create(_colIndexes, ret, _data, getCachedCounts());
}

Expand All @@ -442,7 +442,7 @@ public AColGroup binaryRowOpRight(BinaryOperator op, double[] v, boolean isRowSa
final double[] reference = ColGroupUtils.binaryDefRowRight(op, v, _colIndexes);
return ColGroupDDCFOR.create(_colIndexes, _dict, _data, getCachedCounts(), reference);
}
final ADictionary ret = _dict.binOpRight(op, v, _colIndexes);
final IDictionary ret = _dict.binOpRight(op, v, _colIndexes);
return create(_colIndexes, ret, _data, getCachedCounts());
}

Expand All @@ -454,7 +454,7 @@ public void write(DataOutput out) throws IOException {

public static ColGroupDDC read(DataInput in) throws IOException {
IColIndex cols = ColIndexFactory.read(in);
ADictionary dict = DictionaryFactory.read(in);
IDictionary dict = DictionaryFactory.read(in);
AMapToData data = MapToFactory.readIn(in);
return new ColGroupDDC(cols, dict, data, null);
}
Expand Down Expand Up @@ -494,7 +494,7 @@ public boolean containsValue(double pattern) {
}

@Override
protected AColGroup allocateRightMultiplication(MatrixBlock right, IColIndex colIndexes, ADictionary preAgg) {
protected AColGroup allocateRightMultiplication(MatrixBlock right, IColIndex colIndexes, IDictionary preAgg) {
if(preAgg != null)
return create(colIndexes, preAgg, _data, getCachedCounts());
else
Expand All @@ -512,7 +512,7 @@ public AColGroup sliceRows(int rl, int ru) {
}

@Override
protected AColGroup copyAndSet(IColIndex colIndexes, ADictionary newDictionary) {
protected AColGroup copyAndSet(IColIndex colIndexes, IDictionary newDictionary) {
return create(colIndexes, newDictionary, _data, getCachedCounts());
}

Expand Down
Loading

0 comments on commit 1a51720

Please sign in to comment.