Skip to content

Commit

Permalink
[MINOR] Cleanup AggregateBinary and MatrixObject
Browse files Browse the repository at this point in the history
Fix minor issue in Aggregate Binary when we have an intermediate matrix
that is compressed but it was not through our normal compression
framework.
  • Loading branch information
Baunsgaard committed Aug 8, 2023
1 parent 8dbfc23 commit 89f2ad1
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.controlprogram.ParForProgramBlock.PDataPartitionFormat;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
Expand Down Expand Up @@ -592,6 +593,18 @@ protected MatrixBlock reconstructByLineage(LineageItem li) throws IOException {
.acquireReadAndRelease();
}

@Override
public boolean isCompressed(){
if(super.isCompressed())
return true;
else if(_partitionInMemory instanceof CompressedMatrixBlock){
setCompressedSize(_partitionInMemory.estimateSizeInMemory());
return true;
}
else
return false;
}

@Override
public String toString(){
StringBuilder sb = new StringBuilder(super.toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,49 +61,46 @@ public static AggregateBinaryCPInstruction parseInstruction(String str) {
CPOperand in2 = new CPOperand(parts[2]);
CPOperand out = new CPOperand(parts[3]);
int k = Integer.parseInt(parts[4]);
AggregateBinaryOperator aggbin = InstructionUtils.getMatMultOperator(k);
AggregateBinaryOperator op = InstructionUtils.getMatMultOperator(k);
if(numFields == 6) {
boolean isLeftTransposed = Boolean.parseBoolean(parts[5]);
boolean isRightTransposed = Boolean.parseBoolean(parts[6]);
return new AggregateBinaryCPInstruction(aggbin,
in1, in2, out, opcode, str, isLeftTransposed, isRightTransposed);
boolean lt = Boolean.parseBoolean(parts[5]);
boolean rt = Boolean.parseBoolean(parts[6]);
return new AggregateBinaryCPInstruction(op, in1, in2, out, opcode, str, lt, rt);
}
return new AggregateBinaryCPInstruction(aggbin, in1, in2, out, opcode, str);
return new AggregateBinaryCPInstruction(op, in1, in2, out, opcode, str);
}

@Override
public void processInstruction(ExecutionContext ec) {
MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName());
MatrixBlock matBlock2 = ec.getMatrixInput(input2.getName());
// check compressed inputs
final boolean comp1 = ec.getMatrixObject(input1.getName()).isCompressed();
final boolean comp2 = ec.getMatrixObject(input2.getName()).isCompressed();
final boolean comp1 = matBlock1 instanceof CompressedMatrixBlock;
final boolean comp2 = matBlock2 instanceof CompressedMatrixBlock;

if(comp1 || comp2)
processCompressedAggregateBinary(ec, comp1, comp2);
processCompressedAggregateBinary(ec, matBlock1, matBlock2, comp1, comp2);
else if(transposeLeft || transposeRight)
processTransposedFusedAggregateBinary(ec);
processTransposedFusedAggregateBinary(ec, matBlock1, matBlock2);
else
processNormal(ec);
}
processNormal(ec, matBlock1, matBlock2);

private void processNormal(ExecutionContext ec) {
// get inputs
MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName());
MatrixBlock matBlock2 = ec.getMatrixInput(input2.getName());
}

private void processNormal(ExecutionContext ec, MatrixBlock matBlock1, MatrixBlock matBlock2) {
// compute matrix multiplication
AggregateBinaryOperator ab_op = (AggregateBinaryOperator) _optr;
MatrixBlock ret;

ret = matBlock1.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), ab_op);
MatrixBlock ret = matBlock1.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), ab_op);

// release inputs/outputs
ec.releaseMatrixInput(input1.getName());
ec.releaseMatrixInput(input2.getName());
ec.setMatrixOutput(output.getName(), ret);
}

private void processTransposedFusedAggregateBinary(ExecutionContext ec) {
MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName());
MatrixBlock matBlock2 = ec.getMatrixInput(input2.getName());
private void processTransposedFusedAggregateBinary(ExecutionContext ec, MatrixBlock matBlock1,
MatrixBlock matBlock2) {

// compute matrix multiplication
AggregateBinaryOperator ab_op = (AggregateBinaryOperator) _optr;
MatrixBlock ret;
Expand All @@ -127,9 +124,9 @@ private void processTransposedFusedAggregateBinary(ExecutionContext ec) {
ec.setMatrixOutput(output.getName(), ret);
}

private void processCompressedAggregateBinary(ExecutionContext ec, boolean c1, boolean c2) {
MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName());
MatrixBlock matBlock2 = ec.getMatrixInput(input2.getName());
private void processCompressedAggregateBinary(ExecutionContext ec, MatrixBlock matBlock1, MatrixBlock matBlock2,
boolean c1, boolean c2) {

// compute matrix multiplication
AggregateBinaryOperator ab_op = (AggregateBinaryOperator) _optr;
MatrixBlock ret;
Expand Down

0 comments on commit 89f2ad1

Please sign in to comment.