Skip to content

Commit

Permalink
[SYSTEMDS-3783] Fix wsigmoid rewrite test setup
Browse files Browse the repository at this point in the history
The recent addition of various rewrite tests for code coverage left a
FIXME on the wsigmoid test which gave incorrect results for all
variants without transpose. After double checking, it turns out the
test setup was wrong in the assumptions when the rewrite should apply
(missing transpose) and how the shapes of involved matrices look like.
  • Loading branch information
mboehm7 committed Oct 21, 2024
1 parent 29b3c61 commit 9b6a96d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@

package org.apache.sysds.test.functions.rewrite;

import java.util.HashMap;

import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
Expand All @@ -32,9 +35,8 @@ public class RewriteSimplifyWeightedSigmoidMMChainsTest extends AutomatedTestBas
private static final String TEST_CLASS_DIR =
TEST_DIR + RewriteSimplifyWeightedSigmoidMMChainsTest.class.getSimpleName() + "/";

private static final int rows = 100;
private static final int rows = 150;
private static final int cols = 100;
//private static final double eps = Math.pow(10, -10);

@Override
public void setUp() {
Expand Down Expand Up @@ -125,8 +127,9 @@ private void testRewriteSimplifyWeightedSigmoidMMChains(int ID, boolean rewrites
OptimizerUtils.ALLOW_OPERATOR_FUSION = rewrites;

//create matrices
double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.80d, 3);
double[][] Y = getRandomMatrix(rows, cols, -1, 1, 0.70d, 4);
int rank = 50;
double[][] X = getRandomMatrix(cols, rank, -1, 1, 0.80d, 3);
double[][] Y = getRandomMatrix(rows, rank, -1, 1, 0.70d, 4);
double[][] W = getRandomMatrix(rows, cols, -1, 1, 0.60d, 5);
writeInputMatrixWithMTD("X", X, true);
writeInputMatrixWithMTD("Y", Y, true);
Expand All @@ -136,10 +139,9 @@ private void testRewriteSimplifyWeightedSigmoidMMChains(int ID, boolean rewrites
runRScript(true);

//compare matrices
// FIXME
// HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R");
// HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R");
// compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R");
HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R");
TestUtils.compareMatrices(dmlfile, rfile, 1e-8, "Stat-DML", "Stat-R");

if(rewrites)
Assert.assertTrue(heavyHittersContainsString("wsigmoid"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ Y = read($2)
W = read($3)
type = $4

if( type > 4 )
X = t(X);

# Perform operations
if(type == 1){
Expand Down

0 comments on commit 9b6a96d

Please sign in to comment.