Skip to content

Commit

Permalink
Modify Federated Covariance Test to account for the weighted covarian…
Browse files Browse the repository at this point in the history
…ce cases
  • Loading branch information
gaturchenko committed Nov 13, 2024
1 parent 553d30c commit b8a63a3
Showing 1 changed file with 241 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ public class FederatedCovarianceTest extends AutomatedTestBase {

private final static String TEST_NAME1 = "FederatedCovarianceTest";
private final static String TEST_NAME2 = "FederatedCovarianceAlignedTest";
private final static String TEST_NAME3 = "FederatedCovarianceWeightedTest";
private final static String TEST_NAME4 = "FederatedCovarianceAlignedWeightedTest";
private final static String TEST_NAME5 = "FederatedCovarianceAllAlignedWeightedTest";
private final static String TEST_DIR = "functions/federated/";
private static final String TEST_CLASS_DIR = TEST_DIR + FederatedCovarianceTest.class.getSimpleName() + "/";

Expand All @@ -64,19 +67,37 @@ public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S.scalar"}));
addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S.scalar"}));
addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S.scalar"}));
addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {"S.scalar"}));
addTestConfiguration(TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {"S.scalar"}));
}

@Test
public void testCovCP() {
runCovTest(ExecMode.SINGLE_NODE, false);
runCovarianceTest(ExecMode.SINGLE_NODE, false);
}

@Test
public void testAlignedCovCP() {
runCovTest(ExecMode.SINGLE_NODE, true);
runCovarianceTest(ExecMode.SINGLE_NODE, true);
}

private void runCovTest(ExecMode execMode, boolean alignedFedInput) {
@Test
public void testCovarianceWeightedCP() {
runWeightedCovarianceTest(ExecMode.SINGLE_NODE, false, false);
}

@Test
public void testAlignedCovarianceWeightedCP() {
runWeightedCovarianceTest(ExecMode.SINGLE_NODE, true, false);
}

@Test
public void testAllAlignedCovarianceWeightedCP() {
runWeightedCovarianceTest(ExecMode.SINGLE_NODE, true, true);
}

private void runCovarianceTest(ExecMode execMode, boolean alignedFedInput) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
ExecMode platformOld = rtplatform;

Expand Down Expand Up @@ -190,4 +211,221 @@ private void runCovTest(ExecMode execMode, boolean alignedFedInput) {
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
}

private void runWeightedCovarianceTest(ExecMode execMode, boolean alignedInput, boolean alignedWeights) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
ExecMode platformOld = rtplatform;

if(rtplatform == ExecMode.SPARK)
DMLScript.USE_LOCAL_SPARK_CONFIG = true;

String TEST_NAME = !alignedInput ? TEST_NAME3 : (!alignedWeights ? TEST_NAME4 : TEST_NAME5);
getAndLoadTestConfiguration(TEST_NAME);

String HOME = SCRIPT_DIR + TEST_DIR;

int r = rows / 4;
int c = cols;

fullDMLScriptName = "";

// Create 4 random 5x1 matrices
double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);

// Create a 20x1 weights matrix
double[][] W = getRandomMatrix(rows, c, 0, 1, 1, 3);

MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
writeInputMatrixWithMTD("X1", X1, false, mc);
writeInputMatrixWithMTD("X2", X2, false, mc);
writeInputMatrixWithMTD("X3", X3, false, mc);
writeInputMatrixWithMTD("X4", X4, false, mc);

writeInputMatrixWithMTD("W", W, false, new MatrixCharacteristics(rows, cols, blocksize, r * c));

// empty script name because we don't execute any script, just start the worker
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();

Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
Process t4 = startLocalFedWorker(port4);

try {
if(!isAlive(t1, t2, t3, t4))
throw new RuntimeException("Failed starting federated worker");

rtplatform = execMode;
if(rtplatform == ExecMode.SPARK) {
System.out.println(7);
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
}

TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);

if (alignedInput) {
// Create 4 random 5x1 matrices
double[][] Y1 = getRandomMatrix(r, c, 1, 5, 1, 3);
double[][] Y2 = getRandomMatrix(r, c, 1, 5, 1, 7);
double[][] Y3 = getRandomMatrix(r, c, 1, 5, 1, 8);
double[][] Y4 = getRandomMatrix(r, c, 1, 5, 1, 9);

writeInputMatrixWithMTD("Y1", Y1, false, mc);
writeInputMatrixWithMTD("Y2", Y2, false, mc);
writeInputMatrixWithMTD("Y3", Y3, false, mc);
writeInputMatrixWithMTD("Y4", Y4, false, mc);

if (!alignedWeights) {
// Run reference dml script with a normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
programArgs = new String[] {
"-stats", "100", "-args",
input("X1"),
input("X2"),
input("X3"),
input("X4"),

input("Y1"),
input("Y2"),
input("Y3"),
input("Y4"),

input("W"),
expected("S")
};
runTest(null);

// Run the dml script with federated matrices
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-stats", "100", "-nvargs",
"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
"in_Y1=" + TestUtils.federatedAddress(port1, input("Y1")),

"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
"in_Y2=" + TestUtils.federatedAddress(port2, input("Y2")),

"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
"in_Y3=" + TestUtils.federatedAddress(port3, input("Y3")),

"in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
"in_Y4=" + TestUtils.federatedAddress(port4, input("Y4")),

"in_W1=" + input("W"),
"rows=" + rows, "cols=" + cols,
"out_S=" + output("S")};
runTest(null);
}
else {
double[][] W1 = getRandomMatrix(r, c, 0, 1, 1, 3);
double[][] W2 = getRandomMatrix(r, c, 0, 1, 1, 7);
double[][] W3 = getRandomMatrix(r, c, 0, 1, 1, 8);
double[][] W4 = getRandomMatrix(r, c, 0, 1, 1, 9);

writeInputMatrixWithMTD("W1", W1, false, mc);
writeInputMatrixWithMTD("W2", W2, false, mc);
writeInputMatrixWithMTD("W3", W3, false, mc);
writeInputMatrixWithMTD("W4", W4, false, mc);

// Run reference dml script with a normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
programArgs = new String[] {
"-stats", "100", "-args",
input("X1"),
input("X2"),
input("X3"),
input("X4"),

input("Y1"),
input("Y2"),
input("Y3"),
input("Y4"),

input("W1"),
input("W2"),
input("W3"),
input("W4"),

expected("S")
};
runTest(null);

// Run the dml script with federated matrices and weights
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-stats", "100", "-nvargs",
"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
"in_Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
"in_W1=" + TestUtils.federatedAddress(port1, input("W1")),

"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
"in_Y2=" + TestUtils.federatedAddress(port2, input("Y2")),
"in_W2=" + TestUtils.federatedAddress(port2, input("W2")),

"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
"in_Y3=" + TestUtils.federatedAddress(port3, input("Y3")),
"in_W3=" + TestUtils.federatedAddress(port3, input("W3")),

"in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
"in_Y4=" + TestUtils.federatedAddress(port4, input("Y4")),
"in_W4=" + TestUtils.federatedAddress(port4, input("W4")),

"rows=" + rows, "cols=" + cols,
"out_S=" + output("S")};
runTest(null);
}

}
else {
// Create a random 20x1 input matrix
double[][] Y = getRandomMatrix(rows, c, 1, 5, 1, 3);
writeInputMatrixWithMTD("Y", Y, false, new MatrixCharacteristics(rows, cols, blocksize, r * c));

// Run reference dml script with a normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
programArgs = new String[] {
"-stats", "100", "-args",
input("X1"),
input("X2"),
input("X3"),
input("X4"),

input("Y"), input("W"), expected("S")
};
runTest(null);

// Run the dml script with a federated matrix
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-stats", "100", "-nvargs",
"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
"in_X4=" + TestUtils.federatedAddress(port4, input("X4")),

"in_W1=" + input("W"),
"Y=" + input("Y"),

"rows=" + rows,
"cols=" + cols,
"out_S=" + output("S")};
runTest(null);
}

// compare via files
compareResults(1e-2);
Assert.assertTrue(heavyHittersContainsString("fed_cov"));

}
finally {
TestUtils.shutdownThreads(t1, t2, t3, t4);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
}
}

0 comments on commit b8a63a3

Please sign in to comment.