diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java index 136cdde7f96..24e9fc3c055 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java @@ -43,6 +43,9 @@ public class FederatedCovarianceTest extends AutomatedTestBase { private final static String TEST_NAME1 = "FederatedCovarianceTest"; private final static String TEST_NAME2 = "FederatedCovarianceAlignedTest"; + private final static String TEST_NAME3 = "FederatedCovarianceWeightedTest"; + private final static String TEST_NAME4 = "FederatedCovarianceAlignedWeightedTest"; + private final static String TEST_NAME5 = "FederatedCovarianceAllAlignedWeightedTest"; private final static String TEST_DIR = "functions/federated/"; private static final String TEST_CLASS_DIR = TEST_DIR + FederatedCovarianceTest.class.getSimpleName() + "/"; @@ -64,19 +67,37 @@ public void setUp() { TestUtils.clearAssertionInformation(); addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S.scalar"})); addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S.scalar"})); + addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S.scalar"})); + addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {"S.scalar"})); + addTestConfiguration(TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {"S.scalar"})); } @Test public void testCovCP() { - runCovTest(ExecMode.SINGLE_NODE, false); + runCovarianceTest(ExecMode.SINGLE_NODE, false); } @Test public void testAlignedCovCP() { - runCovTest(ExecMode.SINGLE_NODE, true); + runCovarianceTest(ExecMode.SINGLE_NODE, true); } - private void runCovTest(ExecMode execMode, boolean alignedFedInput) { + @Test + public void testCovarianceWeightedCP() { + runWeightedCovarianceTest(ExecMode.SINGLE_NODE, false, false); + } + + @Test + public void testAlignedCovarianceWeightedCP() { + runWeightedCovarianceTest(ExecMode.SINGLE_NODE, true, false); + } + + @Test + public void testAllAlignedCovarianceWeightedCP() { + runWeightedCovarianceTest(ExecMode.SINGLE_NODE, true, true); + } + + private void runCovarianceTest(ExecMode execMode, boolean alignedFedInput) { boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; ExecMode platformOld = rtplatform; @@ -190,4 +211,221 @@ private void runCovTest(ExecMode execMode, boolean alignedFedInput) { DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } } + + private void runWeightedCovarianceTest(ExecMode execMode, boolean alignedInput, boolean alignedWeights) { + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + ExecMode platformOld = rtplatform; + + if(rtplatform == ExecMode.SPARK) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + + String TEST_NAME = !alignedInput ? TEST_NAME3 : (!alignedWeights ? TEST_NAME4 : TEST_NAME5); + getAndLoadTestConfiguration(TEST_NAME); + + String HOME = SCRIPT_DIR + TEST_DIR; + + int r = rows / 4; + int c = cols; + + fullDMLScriptName = ""; + + // Create 4 random 5x1 matrices + double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3); + double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7); + double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8); + double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9); + + // Create a 20x1 weights matrix + double[][] W = getRandomMatrix(rows, c, 0, 1, 1, 3); + + MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c); + writeInputMatrixWithMTD("X1", X1, false, mc); + writeInputMatrixWithMTD("X2", X2, false, mc); + writeInputMatrixWithMTD("X3", X3, false, mc); + writeInputMatrixWithMTD("X4", X4, false, mc); + + writeInputMatrixWithMTD("W", W, false, new MatrixCharacteristics(rows, cols, blocksize, r * c)); + + // empty script name because we don't execute any script, just start the worker + fullDMLScriptName = ""; + int port1 = getRandomAvailablePort(); + int port2 = getRandomAvailablePort(); + int port3 = getRandomAvailablePort(); + int port4 = getRandomAvailablePort(); + + Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); + Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); + Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); + Process t4 = startLocalFedWorker(port4); + + try { + if(!isAlive(t1, t2, t3, t4)) + throw new RuntimeException("Failed starting federated worker"); + + rtplatform = execMode; + if(rtplatform == ExecMode.SPARK) { + System.out.println(7); + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + } + + TestConfiguration config = availableTestConfigurations.get(TEST_NAME); + loadTestConfiguration(config); + + if (alignedInput) { + // Create 4 random 5x1 matrices + double[][] Y1 = getRandomMatrix(r, c, 1, 5, 1, 3); + double[][] Y2 = getRandomMatrix(r, c, 1, 5, 1, 7); + double[][] Y3 = getRandomMatrix(r, c, 1, 5, 1, 8); + double[][] Y4 = getRandomMatrix(r, c, 1, 5, 1, 9); + + writeInputMatrixWithMTD("Y1", Y1, false, mc); + writeInputMatrixWithMTD("Y2", Y2, false, mc); + writeInputMatrixWithMTD("Y3", Y3, false, mc); + writeInputMatrixWithMTD("Y4", Y4, false, mc); + + if (!alignedWeights) { + // Run reference dml script with a normal matrix + fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; + programArgs = new String[] { + "-stats", "100", "-args", + input("X1"), + input("X2"), + input("X3"), + input("X4"), + + input("Y1"), + input("Y2"), + input("Y3"), + input("Y4"), + + input("W"), + expected("S") + }; + runTest(null); + + // Run the dml script with federated matrices + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "100", "-nvargs", + "in_X1=" + TestUtils.federatedAddress(port1, input("X1")), + "in_Y1=" + TestUtils.federatedAddress(port1, input("Y1")), + + "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), + "in_Y2=" + TestUtils.federatedAddress(port2, input("Y2")), + + "in_X3=" + TestUtils.federatedAddress(port3, input("X3")), + "in_Y3=" + TestUtils.federatedAddress(port3, input("Y3")), + + "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), + "in_Y4=" + TestUtils.federatedAddress(port4, input("Y4")), + + "in_W1=" + input("W"), + "rows=" + rows, "cols=" + cols, + "out_S=" + output("S")}; + runTest(null); + } + else { + double[][] W1 = getRandomMatrix(r, c, 0, 1, 1, 3); + double[][] W2 = getRandomMatrix(r, c, 0, 1, 1, 7); + double[][] W3 = getRandomMatrix(r, c, 0, 1, 1, 8); + double[][] W4 = getRandomMatrix(r, c, 0, 1, 1, 9); + + writeInputMatrixWithMTD("W1", W1, false, mc); + writeInputMatrixWithMTD("W2", W2, false, mc); + writeInputMatrixWithMTD("W3", W3, false, mc); + writeInputMatrixWithMTD("W4", W4, false, mc); + + // Run reference dml script with a normal matrix + fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; + programArgs = new String[] { + "-stats", "100", "-args", + input("X1"), + input("X2"), + input("X3"), + input("X4"), + + input("Y1"), + input("Y2"), + input("Y3"), + input("Y4"), + + input("W1"), + input("W2"), + input("W3"), + input("W4"), + + expected("S") + }; + runTest(null); + + // Run the dml script with federated matrices and weights + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "100", "-nvargs", + "in_X1=" + TestUtils.federatedAddress(port1, input("X1")), + "in_Y1=" + TestUtils.federatedAddress(port1, input("Y1")), + "in_W1=" + TestUtils.federatedAddress(port1, input("W1")), + + "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), + "in_Y2=" + TestUtils.federatedAddress(port2, input("Y2")), + "in_W2=" + TestUtils.federatedAddress(port2, input("W2")), + + "in_X3=" + TestUtils.federatedAddress(port3, input("X3")), + "in_Y3=" + TestUtils.federatedAddress(port3, input("Y3")), + "in_W3=" + TestUtils.federatedAddress(port3, input("W3")), + + "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), + "in_Y4=" + TestUtils.federatedAddress(port4, input("Y4")), + "in_W4=" + TestUtils.federatedAddress(port4, input("W4")), + + "rows=" + rows, "cols=" + cols, + "out_S=" + output("S")}; + runTest(null); + } + + } + else { + // Create a random 20x1 input matrix + double[][] Y = getRandomMatrix(rows, c, 1, 5, 1, 3); + writeInputMatrixWithMTD("Y", Y, false, new MatrixCharacteristics(rows, cols, blocksize, r * c)); + + // Run reference dml script with a normal matrix + fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; + programArgs = new String[] { + "-stats", "100", "-args", + input("X1"), + input("X2"), + input("X3"), + input("X4"), + + input("Y"), input("W"), expected("S") + }; + runTest(null); + + // Run the dml script with a federated matrix + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "100", "-nvargs", + "in_X1=" + TestUtils.federatedAddress(port1, input("X1")), + "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), + "in_X3=" + TestUtils.federatedAddress(port3, input("X3")), + "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), + + "in_W1=" + input("W"), + "Y=" + input("Y"), + + "rows=" + rows, + "cols=" + cols, + "out_S=" + output("S")}; + runTest(null); + } + + // compare via files + compareResults(1e-2); + Assert.assertTrue(heavyHittersContainsString("fed_cov")); + + } + finally { + TestUtils.shutdownThreads(t1, t2, t3, t4); + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } } diff --git a/src/test/scripts/functions/federated/FederatedCovarianceAlignedWeightedTest.dml b/src/test/scripts/functions/federated/FederatedCovarianceAlignedWeightedTest.dml new file mode 100644 index 00000000000..da9db2f4dea --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedCovarianceAlignedWeightedTest.dml @@ -0,0 +1,35 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# 5x1 on 4 workers -> 20x1 +X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4), + ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols), + list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols))); + +# 5x1 on 4 workers -> 20x1 +Y = federated(addresses=list($in_Y1, $in_Y2, $in_Y3, $in_Y4), + ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols), + list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols))); + +W = read($in_W1); # 20x1 + +s = cov(X, Y, W); +write(s, $out_S); \ No newline at end of file diff --git a/src/test/scripts/functions/federated/FederatedCovarianceAlignedWeightedTestReference.dml b/src/test/scripts/functions/federated/FederatedCovarianceAlignedWeightedTestReference.dml new file mode 100644 index 00000000000..ee4062f7e69 --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedCovarianceAlignedWeightedTestReference.dml @@ -0,0 +1,27 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = rbind(read($1), read($2), read($3), read($4)); # 20x1 +Y = rbind(read($5), read($6), read($7), read($8)); # 20x1 +W = read($9); # 20x1 + +s = cov(X, Y, W); +write(s, $10); \ No newline at end of file diff --git a/src/test/scripts/functions/federated/FederatedCovarianceAllAlignedWeightedTest.dml b/src/test/scripts/functions/federated/FederatedCovarianceAllAlignedWeightedTest.dml new file mode 100644 index 00000000000..22029de451d --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedCovarianceAllAlignedWeightedTest.dml @@ -0,0 +1,38 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# 5x1 on 4 workers -> 20x1 +X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4), + ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols), + list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols))); + +# 5x1 on 4 workers -> 20x1 +Y = federated(addresses=list($in_Y1, $in_Y2, $in_Y3, $in_Y4), + ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols), + list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols))); + +# 5x1 on 4 workers -> 20x1 +W = federated(addresses=list($in_W1, $in_W2, $in_W3, $in_W4), + ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols), + list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols))); + +s = cov(X, Y, W); +write(s, $out_S); \ No newline at end of file diff --git a/src/test/scripts/functions/federated/FederatedCovarianceAllAlignedWeightedTestReference.dml b/src/test/scripts/functions/federated/FederatedCovarianceAllAlignedWeightedTestReference.dml new file mode 100644 index 00000000000..10c18f5a333 --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedCovarianceAllAlignedWeightedTestReference.dml @@ -0,0 +1,27 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = rbind(read($1), read($2), read($3), read($4)); # 20x1 +Y = rbind(read($5), read($6), read($7), read($8)); # 20x1 +W = rbind(read($9), read($10), read($11), read($12)); # 20x1 + +s = cov(X, Y, W); +write(s, $13); \ No newline at end of file diff --git a/src/test/scripts/functions/federated/FederatedCovarianceWeightedTest.dml b/src/test/scripts/functions/federated/FederatedCovarianceWeightedTest.dml new file mode 100644 index 00000000000..3ba2d5b15f8 --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedCovarianceWeightedTest.dml @@ -0,0 +1,31 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# 5x1 on 4 workers -> 20x1 +X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4), + ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols), + list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols))); + +Y = read($Y); # 20x1 +W = read($in_W1); # 20x1 + +s = cov(X, Y, W); +write(s, $out_S); \ No newline at end of file diff --git a/src/test/scripts/functions/federated/FederatedCovarianceWeightedTestReference.dml b/src/test/scripts/functions/federated/FederatedCovarianceWeightedTestReference.dml new file mode 100644 index 00000000000..db1dc7c5265 --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedCovarianceWeightedTestReference.dml @@ -0,0 +1,27 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = rbind(read($1), read($2), read($3), read($4)); # 20x1 +Y = read($5); # 20x1 +W = read($6); # 20x1 + +s = cov(X, Y, W); +write(s, $7); \ No newline at end of file