Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYSTEMDS-3788] Modify the FederatedCovarianceTest to account for weighted covariance #2136

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ public class FederatedCovarianceTest extends AutomatedTestBase {

// Base names of the DML scripts; each scenario uses "<name>.dml" and "<name>Reference.dml".
private static final String TEST_NAME1 = "FederatedCovarianceTest";
private static final String TEST_NAME2 = "FederatedCovarianceAlignedTest";
// Weighted variants: federated X only, federated X and Y, federated X, Y and weights.
private static final String TEST_NAME3 = "FederatedCovarianceWeightedTest";
private static final String TEST_NAME4 = "FederatedCovarianceAlignedWeightedTest";
private static final String TEST_NAME5 = "FederatedCovarianceAllAlignedWeightedTest";
// Script directory (relative to SCRIPT_DIR) and the per-class directory used by the test configurations.
private static final String TEST_DIR = "functions/federated/";
private static final String TEST_CLASS_DIR = TEST_DIR + FederatedCovarianceTest.class.getSimpleName() + "/";

Expand All @@ -64,19 +67,37 @@ public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S.scalar"}));
addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S.scalar"}));
addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S.scalar"}));
addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {"S.scalar"}));
addTestConfiguration(TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {"S.scalar"}));
}

/**
 * Unweighted federated covariance in single-node mode with a non-federated second input
 * ({@code alignedFedInput = false}).
 */
@Test
public void testCovCP() {
	// Only the renamed helper is invoked; the former runCovTest name no longer has a definition.
	runCovarianceTest(ExecMode.SINGLE_NODE, false);
}

/**
 * Unweighted federated covariance in single-node mode with an aligned federated second input
 * ({@code alignedFedInput = true}).
 */
@Test
public void testAlignedCovCP() {
	// Only the renamed helper is invoked; the former runCovTest name no longer has a definition.
	runCovarianceTest(ExecMode.SINGLE_NODE, true);
}

private void runCovTest(ExecMode execMode, boolean alignedFedInput) {
/** Weighted covariance in single-node mode: only X is federated, Y and the weights are local. */
@Test
public void testCovarianceWeightedCP() {
	final boolean alignedInput = false;
	final boolean alignedWeights = false;
	runWeightedCovarianceTest(ExecMode.SINGLE_NODE, alignedInput, alignedWeights);
}

/** Weighted covariance in single-node mode: X and Y are federated, the weights are local. */
@Test
public void testAlignedCovarianceWeightedCP() {
	final boolean alignedInput = true;
	final boolean alignedWeights = false;
	runWeightedCovarianceTest(ExecMode.SINGLE_NODE, alignedInput, alignedWeights);
}

/** Weighted covariance in single-node mode: X, Y and the weights are all federated. */
@Test
public void testAllAlignedCovarianceWeightedCP() {
	final boolean alignedInput = true;
	final boolean alignedWeights = true;
	runWeightedCovarianceTest(ExecMode.SINGLE_NODE, alignedInput, alignedWeights);
}

private void runCovarianceTest(ExecMode execMode, boolean alignedFedInput) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
ExecMode platformOld = rtplatform;

Expand Down Expand Up @@ -190,4 +211,221 @@ private void runCovTest(ExecMode execMode, boolean alignedFedInput) {
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
}

/**
 * Runs one weighted covariance scenario twice, once through the local reference script
 * ({@code <name>Reference.dml}) and once through the federated script ({@code <name>.dml}),
 * and compares the two results.
 *
 * <p>The scenario is selected by the two flags:
 * <ul>
 * <li>{@code alignedInput == false}: TEST_NAME3, X is federated, Y and W are read locally;</li>
 * <li>{@code alignedInput == true, alignedWeights == false}: TEST_NAME4, X and Y are federated,
 * W is read locally;</li>
 * <li>{@code alignedInput == true, alignedWeights == true}: TEST_NAME5, X, Y and W are all
 * federated.</li>
 * </ul>
 *
 * <p>Four local federated workers are started for the duration of the run and shut down in the
 * {@code finally} block, where the previous execution mode and Spark config flag are restored.
 *
 * @param execMode       execution mode to run both scripts in
 * @param alignedInput   whether Y is federated on the same workers as X
 * @param alignedWeights whether the weights are federated as well (only read when
 *                       {@code alignedInput} is true)
 */
private void runWeightedCovarianceTest(ExecMode execMode, boolean alignedInput, boolean alignedWeights) {
	boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
	ExecMode platformOld = rtplatform;

	// This inspects the platform that is active before the switch to execMode below;
	// the requested mode is handled again inside the try block.
	if(rtplatform == ExecMode.SPARK)
		DMLScript.USE_LOCAL_SPARK_CONFIG = true;

	String testName = !alignedInput ? TEST_NAME3 : (!alignedWeights ? TEST_NAME4 : TEST_NAME5);
	getAndLoadTestConfiguration(testName);

	String home = SCRIPT_DIR + TEST_DIR;

	// Every worker holds one quarter of the rows.
	int r = rows / 4;
	int c = cols;

	// Metadata for one partition and for a full-height matrix. The non-zero count has to
	// describe the matrix it is attached to, hence rows * cols for the full-height one.
	MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
	MatrixCharacteristics mcFull = new MatrixCharacteristics(rows, cols, blocksize, rows * cols);

	// Four partitions of X, one per worker
	double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
	double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
	double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
	double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);

	// Full-height weights matrix with values in [0, 1]
	double[][] W = getRandomMatrix(rows, c, 0, 1, 1, 3);

	writeInputMatrixWithMTD("X1", X1, false, mc);
	writeInputMatrixWithMTD("X2", X2, false, mc);
	writeInputMatrixWithMTD("X3", X3, false, mc);
	writeInputMatrixWithMTD("X4", X4, false, mc);

	writeInputMatrixWithMTD("W", W, false, mcFull);

	// empty script name because we don't execute any script, just start the worker
	fullDMLScriptName = "";
	int port1 = getRandomAvailablePort();
	int port2 = getRandomAvailablePort();
	int port3 = getRandomAvailablePort();
	int port4 = getRandomAvailablePort();

	Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
	Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
	Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
	// The last worker is started with the default wait, as in the original code.
	Process t4 = startLocalFedWorker(port4);

	try {
		if(!isAlive(t1, t2, t3, t4))
			throw new RuntimeException("Failed starting federated worker");

		rtplatform = execMode;
		if(rtplatform == ExecMode.SPARK)
			DMLScript.USE_LOCAL_SPARK_CONFIG = true;

		TestConfiguration config = availableTestConfigurations.get(testName);
		loadTestConfiguration(config);

		// Arguments of the local reference run and of the federated run; filled per scenario.
		String[] referenceArgs;
		String[] federatedArgs;

		if(alignedInput) {
			// Four partitions of Y, co-located with the X partitions.
			// NOTE(review): Y1..Y4 are generated with the same arguments as X1..X4. If
			// getRandomMatrix is deterministic per seed, Y equals X and this only checks a
			// weighted variance; confirm whether distinct seeds were intended.
			double[][] Y1 = getRandomMatrix(r, c, 1, 5, 1, 3);
			double[][] Y2 = getRandomMatrix(r, c, 1, 5, 1, 7);
			double[][] Y3 = getRandomMatrix(r, c, 1, 5, 1, 8);
			double[][] Y4 = getRandomMatrix(r, c, 1, 5, 1, 9);

			writeInputMatrixWithMTD("Y1", Y1, false, mc);
			writeInputMatrixWithMTD("Y2", Y2, false, mc);
			writeInputMatrixWithMTD("Y3", Y3, false, mc);
			writeInputMatrixWithMTD("Y4", Y4, false, mc);

			if(!alignedWeights) {
				// X and Y federated, weights read locally by the coordinator
				referenceArgs = new String[] {
					"-stats", "100", "-args",
					input("X1"),
					input("X2"),
					input("X3"),
					input("X4"),

					input("Y1"),
					input("Y2"),
					input("Y3"),
					input("Y4"),

					input("W"),
					expected("S")
				};

				federatedArgs = new String[] {"-stats", "100", "-nvargs",
					"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
					"in_Y1=" + TestUtils.federatedAddress(port1, input("Y1")),

					"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
					"in_Y2=" + TestUtils.federatedAddress(port2, input("Y2")),

					"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
					"in_Y3=" + TestUtils.federatedAddress(port3, input("Y3")),

					"in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
					"in_Y4=" + TestUtils.federatedAddress(port4, input("Y4")),

					"in_W1=" + input("W"),
					"rows=" + rows, "cols=" + cols,
					"out_S=" + output("S")};
			}
			else {
				// X, Y and the weights are all federated; the weights are partitioned as well
				double[][] W1 = getRandomMatrix(r, c, 0, 1, 1, 3);
				double[][] W2 = getRandomMatrix(r, c, 0, 1, 1, 7);
				double[][] W3 = getRandomMatrix(r, c, 0, 1, 1, 8);
				double[][] W4 = getRandomMatrix(r, c, 0, 1, 1, 9);

				writeInputMatrixWithMTD("W1", W1, false, mc);
				writeInputMatrixWithMTD("W2", W2, false, mc);
				writeInputMatrixWithMTD("W3", W3, false, mc);
				writeInputMatrixWithMTD("W4", W4, false, mc);

				referenceArgs = new String[] {
					"-stats", "100", "-args",
					input("X1"),
					input("X2"),
					input("X3"),
					input("X4"),

					input("Y1"),
					input("Y2"),
					input("Y3"),
					input("Y4"),

					input("W1"),
					input("W2"),
					input("W3"),
					input("W4"),

					expected("S")
				};

				federatedArgs = new String[] {"-stats", "100", "-nvargs",
					"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
					"in_Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
					"in_W1=" + TestUtils.federatedAddress(port1, input("W1")),

					"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
					"in_Y2=" + TestUtils.federatedAddress(port2, input("Y2")),
					"in_W2=" + TestUtils.federatedAddress(port2, input("W2")),

					"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
					"in_Y3=" + TestUtils.federatedAddress(port3, input("Y3")),
					"in_W3=" + TestUtils.federatedAddress(port3, input("W3")),

					"in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
					"in_Y4=" + TestUtils.federatedAddress(port4, input("Y4")),
					"in_W4=" + TestUtils.federatedAddress(port4, input("W4")),

					"rows=" + rows, "cols=" + cols,
					"out_S=" + output("S")};
			}
		}
		else {
			// Only X is federated; Y is a full-height local matrix
			double[][] Y = getRandomMatrix(rows, c, 1, 5, 1, 3);
			writeInputMatrixWithMTD("Y", Y, false, mcFull);

			referenceArgs = new String[] {
				"-stats", "100", "-args",
				input("X1"),
				input("X2"),
				input("X3"),
				input("X4"),

				input("Y"), input("W"), expected("S")
			};

			federatedArgs = new String[] {"-stats", "100", "-nvargs",
				"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
				"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
				"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
				"in_X4=" + TestUtils.federatedAddress(port4, input("X4")),

				"in_W1=" + input("W"),
				"Y=" + input("Y"),

				"rows=" + rows,
				"cols=" + cols,
				"out_S=" + output("S")};
		}

		// Run reference dml script with normal (non-federated) matrices
		fullDMLScriptName = home + testName + "Reference.dml";
		programArgs = referenceArgs;
		runTest(null);

		// Run the dml script with federated matrices
		fullDMLScriptName = home + testName + ".dml";
		programArgs = federatedArgs;
		runTest(null);

		// compare via files
		compareResults(1e-2);
		// The federated covariance instruction must show up among the heavy hitters.
		Assert.assertTrue(heavyHittersContainsString("fed_cov"));
	}
	finally {
		TestUtils.shutdownThreads(t1, t2, t3, t4);
		rtplatform = platformOld;
		DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
	}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

# Weighted covariance where X and Y are federated and the weights W are read locally.
# Named arguments: in_X1..in_X4 and in_Y1..in_Y4 (worker addresses of the partitions),
# in_W1 (path of the weights matrix), rows/cols (dimensions of the assembled matrices),
# out_S (output path of the resulting scalar).

# X: 5x1 on 4 workers -> 20x1
# Each address is paired with two consecutive range entries (a begin and an end coordinate);
# the four partitions are stacked by rows, each covering $rows/4 rows.
X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));

# Y: 5x1 on 4 workers -> 20x1, partitioned over the same row ranges as X
Y = federated(addresses=list($in_Y1, $in_Y2, $in_Y3, $in_Y4),
ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));

# Weights are a regular (non-federated) matrix read from file
W = read($in_W1); # 20x1

# Weighted covariance of X and Y, written out as a scalar
s = cov(X, Y, W);
write(s, $out_S);
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

# Local reference for the weighted covariance with partitioned X and Y.
# Positional arguments: $1-$4 partitions of X, $5-$8 partitions of Y,
# $9 weights matrix, $10 output path of the resulting scalar.

# Stack the four partitions by rows to obtain the full inputs
X = rbind(read($1), read($2), read($3), read($4)); # 20x1
Y = rbind(read($5), read($6), read($7), read($8)); # 20x1
W = read($9); # 20x1

# Weighted covariance of X and Y, written out as a scalar
s = cov(X, Y, W);
write(s, $10);
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

# Weighted covariance where X, Y and the weights W are all federated.
# Named arguments: in_X1..in_X4, in_Y1..in_Y4 and in_W1..in_W4 (worker addresses of the
# partitions), rows/cols (dimensions of the assembled matrices), out_S (output path).

# X: 5x1 on 4 workers -> 20x1
# Each address is paired with two consecutive range entries (a begin and an end coordinate);
# the four partitions are stacked by rows, each covering $rows/4 rows.
X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));

# Y: 5x1 on 4 workers -> 20x1, partitioned over the same row ranges as X
Y = federated(addresses=list($in_Y1, $in_Y2, $in_Y3, $in_Y4),
ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));

# W: 5x1 on 4 workers -> 20x1, partitioned over the same row ranges as X and Y
W = federated(addresses=list($in_W1, $in_W2, $in_W3, $in_W4),
ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));

# Weighted covariance of X and Y, written out as a scalar
s = cov(X, Y, W);
write(s, $out_S);
Loading
Loading