diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java index 63e632c895c..7d7a7743c35 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java @@ -35,6 +35,7 @@ import org.apache.sysds.hops.fedplanner.FTypes.FType; import org.apache.sysds.lops.Lop; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.controlprogram.ParForProgramBlock.PDataPartitionFormat; import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysds.runtime.controlprogram.federated.FederatedRange; @@ -592,6 +593,18 @@ protected MatrixBlock reconstructByLineage(LineageItem li) throws IOException { .acquireReadAndRelease(); } + @Override + public boolean isCompressed(){ + if(super.isCompressed()) + return true; + else if(_partitionInMemory instanceof CompressedMatrixBlock){ + setCompressedSize(_partitionInMemory.estimateSizeInMemory()); + return true; + } + else + return false; + } + @Override public String toString(){ StringBuilder sb = new StringBuilder(super.toString()); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateBinaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateBinaryCPInstruction.java index c17fcb0ad40..4e40d62f4a9 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateBinaryCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateBinaryCPInstruction.java @@ -61,39 +61,36 @@ public static AggregateBinaryCPInstruction parseInstruction(String str) { CPOperand in2 = new CPOperand(parts[2]); CPOperand out = new CPOperand(parts[3]); int k = Integer.parseInt(parts[4]); - AggregateBinaryOperator aggbin = InstructionUtils.getMatMultOperator(k); + AggregateBinaryOperator op = InstructionUtils.getMatMultOperator(k); if(numFields == 6) { - boolean isLeftTransposed = Boolean.parseBoolean(parts[5]); - boolean isRightTransposed = Boolean.parseBoolean(parts[6]); - return new AggregateBinaryCPInstruction(aggbin, - in1, in2, out, opcode, str, isLeftTransposed, isRightTransposed); + boolean lt = Boolean.parseBoolean(parts[5]); + boolean rt = Boolean.parseBoolean(parts[6]); + return new AggregateBinaryCPInstruction(op, in1, in2, out, opcode, str, lt, rt); } - return new AggregateBinaryCPInstruction(aggbin, in1, in2, out, opcode, str); + return new AggregateBinaryCPInstruction(op, in1, in2, out, opcode, str); } @Override public void processInstruction(ExecutionContext ec) { + MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName()); + MatrixBlock matBlock2 = ec.getMatrixInput(input2.getName()); // check compressed inputs - final boolean comp1 = ec.getMatrixObject(input1.getName()).isCompressed(); - final boolean comp2 = ec.getMatrixObject(input2.getName()).isCompressed(); + final boolean comp1 = matBlock1 instanceof CompressedMatrixBlock; + final boolean comp2 = matBlock2 instanceof CompressedMatrixBlock; + if(comp1 || comp2) - processCompressedAggregateBinary(ec, comp1, comp2); + processCompressedAggregateBinary(ec, matBlock1, matBlock2, comp1, comp2); else if(transposeLeft || transposeRight) - processTransposedFusedAggregateBinary(ec); + processTransposedFusedAggregateBinary(ec, matBlock1, matBlock2); else - processNormal(ec); - } + processNormal(ec, matBlock1, matBlock2); - private void processNormal(ExecutionContext ec) { - // get inputs - MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName()); - MatrixBlock matBlock2 = ec.getMatrixInput(input2.getName()); + } + private void processNormal(ExecutionContext ec, MatrixBlock matBlock1, MatrixBlock matBlock2) { // compute matrix multiplication AggregateBinaryOperator ab_op = (AggregateBinaryOperator) _optr; - MatrixBlock ret; - - ret = matBlock1.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), ab_op); + MatrixBlock ret = matBlock1.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), ab_op); // release inputs/outputs ec.releaseMatrixInput(input1.getName()); @@ -101,9 +98,9 @@ private void processNormal(ExecutionContext ec) { ec.setMatrixOutput(output.getName(), ret); } - private void processTransposedFusedAggregateBinary(ExecutionContext ec) { - MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName()); - MatrixBlock matBlock2 = ec.getMatrixInput(input2.getName()); + private void processTransposedFusedAggregateBinary(ExecutionContext ec, MatrixBlock matBlock1, + MatrixBlock matBlock2) { + // compute matrix multiplication AggregateBinaryOperator ab_op = (AggregateBinaryOperator) _optr; MatrixBlock ret; @@ -127,9 +124,9 @@ private void processTransposedFusedAggregateBinary(ExecutionContext ec) { ec.setMatrixOutput(output.getName(), ret); } - private void processCompressedAggregateBinary(ExecutionContext ec, boolean c1, boolean c2) { - MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName()); - MatrixBlock matBlock2 = ec.getMatrixInput(input2.getName()); + private void processCompressedAggregateBinary(ExecutionContext ec, MatrixBlock matBlock1, MatrixBlock matBlock2, + boolean c1, boolean c2) { + // compute matrix multiplication AggregateBinaryOperator ab_op = (AggregateBinaryOperator) _optr; MatrixBlock ret;