From 7a9ecff8014c9ba490910f7058082d2faba42ea3 Mon Sep 17 00:00:00 2001 From: min-guk Date: Sun, 17 Nov 2024 14:04:21 +0100 Subject: [PATCH] [SYSTEMDS-3790] Restore federated planner tests (all, heuristic) Closes #2139. --- .github/workflows/javaTests.yml | 2 +- .../apache/sysds/hops/fedplanner/FTypes.java | 9 +- .../FederatedDynamicPlanningTest.java | 176 +++++++++ .../FederatedKMeansPlanningTest.java | 156 ++++++++ .../FederatedL2SVMPlanningTest.java | 185 +++++++++ .../FederatedMultiplyPlanningTest.java | 318 +++++++++++++++ .../functions/fedplanning/FTypeCombTest.java | 71 ++++ .../FederatedCostEstimatorTest.java | 373 ++++++++++++++++++ .../FederatedDynamicPlanningTest.java | 188 +++++++++ .../FederatedKMeansPlanningTest.java | 168 ++++++++ .../FederatedL2SVMPlanningTest.java | 202 ++++++++++ .../FederatedMultiplyPlanningTest.java | 334 ++++++++++++++++ 12 files changed, 2174 insertions(+), 8 deletions(-) create mode 100644 src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedDynamicPlanningTest.java create mode 100644 src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedKMeansPlanningTest.java create mode 100644 src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java create mode 100644 src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedMultiplyPlanningTest.java create mode 100644 src/test/java/org/apache/sysds/test/functions/fedplanning/FTypeCombTest.java create mode 100644 src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedCostEstimatorTest.java create mode 100644 src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedDynamicPlanningTest.java create mode 100644 src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedKMeansPlanningTest.java create mode 100644 src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedL2SVMPlanningTest.java create mode 100644 src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedMultiplyPlanningTest.java diff --git a/.github/workflows/javaTests.yml b/.github/workflows/javaTests.yml index cd6e28670e2..9f0258a0d15 100644 --- a/.github/workflows/javaTests.yml +++ b/.github/workflows/javaTests.yml @@ -62,7 +62,7 @@ jobs: "**.functions.compress.**,**.functions.data.tensor.**,**.functions.codegenalg.parttwo.**,**.functions.codegen.**,**.functions.caching.**", "**.functions.binary.matrix_full_cellwise.**,**.functions.binary.matrix_full_other.**", "**.functions.federated.algorithms.**,**.functions.federated.io.**,**.functions.federated.paramserv.**", - "**.functions.federated.transform.**", + "**.functions.federated.transform.**,**.functions.federated.fedplanner.**", "**.functions.federated.primitives.part1.** -Dtest-threadCount=1 -Dtest-forkCount=1", "**.functions.federated.primitives.part2.** -Dtest-threadCount=1 -Dtest-forkCount=1", "**.functions.federated.primitives.part3.** -Dtest-threadCount=1 -Dtest-forkCount=1", diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java b/src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java index a82a56e88b9..de9c9cb670e 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java @@ -44,13 +44,8 @@ public boolean isCompiled() { return this != NONE && this != RUNTIME; } public static boolean isCompiled(String planner) { - try { - return FederatedPlanner.valueOf(planner.toUpperCase()).isCompiled(); - } - catch(Exception ex) { - ex.printStackTrace(); - return false; - } + return planner != null + && FederatedPlanner.valueOf(planner.toUpperCase()).isCompiled(); } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedDynamicPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedDynamicPlanningTest.java new file mode 100644 index 00000000000..bd098bf8271 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedDynamicPlanningTest.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.federated.fedplanning; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Ignore; +import org.junit.Test; + +import java.io.File; +import java.util.Arrays; + +import static org.junit.Assert.fail; + +@net.jcip.annotations.NotThreadSafe +public class FederatedDynamicPlanningTest extends AutomatedTestBase { + private static final Log LOG = LogFactory.getLog(FederatedDynamicPlanningTest.class.getName()); + + private final static String TEST_DIR = "functions/privacy/fedplanning/"; + private final static String TEST_NAME = "FederatedDynamicFunctionPlanningTest"; + private final static String TEST_CLASS_DIR = TEST_DIR + FederatedDynamicPlanningTest.class.getSimpleName() + "/"; + private static File TEST_CONF_FILE; + + private final static int blocksize = 1024; + public final int rows = 1000; + public final int cols = 1000; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"})); + } + + @Test + @Ignore + public void runDynamicFullFunctionTest() { + // compared to `FederatedL2SVMPlanningTest` this does not create `fed_+*` or `fed_tsmm`, probably due to + // some rewrites not being applied. Might be a bug. + String[] expectedHeavyHitters = new String[] {"fed_fedinit", "fed_ba+*", "fed_tak+*", "fed_max", + "fed_1-*", "fed_>"}; + setTestConf("SystemDS-config-fout.xml"); + loadAndRunTest(expectedHeavyHitters, TEST_NAME); + } + + @Test + @Ignore + public void runDynamicHeuristicFunctionTest() { + // compared to `FederatedL2SVMPlanningTest` this does not create `fed_+*` or `fed_tsmm`, probably due to + // some rewrites not being applied. Might be a bug. + String[] expectedHeavyHitters = new String[] {"fed_fedinit", "fed_ba+*"}; + setTestConf("SystemDS-config-heuristic.xml"); + loadAndRunTest(expectedHeavyHitters, TEST_NAME); + } + + private void setTestConf(String test_conf) { + TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, test_conf); + } + + private void writeInputMatrices() { + writeBinaryVector("A", 42, rows); + writeStandardMatrix("B1", 65, rows / 2, cols); + writeStandardMatrix("B2", 75, rows / 2, cols); + writeStandardMatrix("C1", 13, rows, cols / 2); + writeStandardMatrix("C2", 17, rows, cols / 2); + } + + private void writeBinaryVector(String matrixName, long seed, int numRows){ + double[][] matrix = getRandomMatrix(numRows, 1, -1, 1, 1, seed); + for(int i = 0; i < numRows; i++) + matrix[i][0] = (matrix[i][0] > 0) ? 1 : -1; + MatrixCharacteristics mc = new MatrixCharacteristics(numRows, 1, blocksize, numRows); + writeInputMatrixWithMTD(matrixName, matrix, false, mc); + } + + private void writeStandardMatrix(String matrixName, long seed, int numRows, int numCols) { + double[][] matrix = getRandomMatrix(numRows, numCols, 0, 1, 1, seed); + writeStandardMatrix(matrixName, numRows, numCols, matrix); + } + + private void writeStandardMatrix(String matrixName, int numRows, int numCols, double[][] matrix) { + MatrixCharacteristics mc = new MatrixCharacteristics(numRows, numCols, blocksize, (long) numRows * numCols); + writeInputMatrixWithMTD(matrixName, matrix, false, mc); + } + + private void loadAndRunTest(String[] expectedHeavyHitters, String testName) { + + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + Types.ExecMode platformOld = rtplatform; + rtplatform = Types.ExecMode.SINGLE_NODE; + + Thread t1 = null, t2 = null; + + try { + getAndLoadTestConfiguration(testName); + String HOME = SCRIPT_DIR + TEST_DIR; + + writeInputMatrices(); + + int port1 = getRandomAvailablePort(); + int port2 = getRandomAvailablePort(); + t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); + t2 = startLocalFedWorkerThread(port2); + + // Run actual dml script with federated matrix + fullDMLScriptName = HOME + testName + ".dml"; + programArgs = new String[] {"-stats", "-nvargs", + "r=" + rows, "c=" + cols, + "A=" + input("A"), + "B1=" + TestUtils.federatedAddress(port1, input("B1")), + "B2=" + TestUtils.federatedAddress(port2, input("B2")), + "C1=" + TestUtils.federatedAddress(port1, input("C1")), + "C2=" + TestUtils.federatedAddress(port2, input("C2")), + "lB1=" + input("B1"), + "lB2=" + input("B2"), + "Z=" + output("Z")}; + runTest(true, false, null, -1); + + // Run reference dml script with normal matrix + fullDMLScriptName = HOME + testName + "Reference.dml"; + programArgs = new String[] {"-nvargs", + "r=" + rows, "c=" + cols, + "A=" + input("A"), + "B1=" + input("B1"), + "B2=" + input("B2"), + "C1=" + input("C1"), + "C2=" + input("C2"), + "Z=" + expected("Z")}; + runTest(true, false, null, -1); + + // compare via files + compareResults(1e-9); + if(!heavyHittersContainsAllString(expectedHeavyHitters)) + fail("The following expected heavy hitters are missing: " + + Arrays.toString(missingHeavyHitters(expectedHeavyHitters))); + } + finally { + TestUtils.shutdownThreads(t1, t2); + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } + + /** + * Override default configuration with custom test configuration to ensure scratch space and local temporary + * directory locations are also updated. + */ + @Override + protected File getConfigTemplateFile() { + // Instrumentation in this test's output log to show custom configuration file used for template. + LOG.info("This test case overrides default configuration with " + TEST_CONF_FILE.getPath()); + return TEST_CONF_FILE; + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedKMeansPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedKMeansPlanningTest.java new file mode 100644 index 00000000000..326516d4234 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedKMeansPlanningTest.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.federated.fedplanning; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +import java.io.File; +import java.util.Arrays; + +import static org.junit.Assert.fail; + +public class FederatedKMeansPlanningTest extends AutomatedTestBase { + private static final Log LOG = LogFactory.getLog(FederatedKMeansPlanningTest.class.getName()); + + private final static String TEST_DIR = "functions/privacy/fedplanning/"; + private final static String TEST_NAME = "FederatedKMeansPlanningTest"; + private final static String TEST_CLASS_DIR = TEST_DIR + FederatedKMeansPlanningTest.class.getSimpleName() + "/"; + private static File TEST_CONF_FILE; + + private final static int blocksize = 1024; + public final int rows = 1000; + public final int cols = 100; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"})); + } + + @Test + public void runKMeansFOUTTest(){ + String[] expectedHeavyHitters = new String[]{}; + setTestConf("SystemDS-config-fout.xml"); + loadAndRunTest(expectedHeavyHitters, TEST_NAME); + } + + @Test + public void runKMeansHeuristicTest(){ + String[] expectedHeavyHitters = new String[]{}; + setTestConf("SystemDS-config-heuristic.xml"); + loadAndRunTest(expectedHeavyHitters, TEST_NAME); + } + + @Test + public void runRuntimeTest(){ + String[] expectedHeavyHitters = new String[]{}; + TEST_CONF_FILE = new File("src/test/config/SystemDS-config.xml"); + loadAndRunTest(expectedHeavyHitters, TEST_NAME); + } + + private void setTestConf(String test_conf){ + TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, test_conf); + } + + /** + * Override default configuration with custom test configuration to ensure + * scratch space and local temporary directory locations are also updated. + */ + @Override + protected File getConfigTemplateFile() { + // Instrumentation in this test's output log to show custom configuration file used for template. + LOG.info("This test case overrides default configuration with " + TEST_CONF_FILE.getPath()); + return TEST_CONF_FILE; + } + + private void writeInputMatrices(){ + writeStandardRowFedMatrix("X1", 65); + writeStandardRowFedMatrix("X2", 75); + } + + private void writeStandardMatrix(String matrixName, long seed, int numRows){ + double[][] matrix = getRandomMatrix(numRows, cols, 0, 1, 1, seed); + writeStandardMatrix(matrixName, numRows, matrix); + } + + private void writeStandardMatrix(String matrixName, int numRows, double[][] matrix){ + MatrixCharacteristics mc = new MatrixCharacteristics(numRows, cols, blocksize, (long) numRows * cols); + writeInputMatrixWithMTD(matrixName, matrix, false, mc); + } + + private void writeStandardRowFedMatrix(String matrixName, long seed){ + int halfRows = rows/2; + writeStandardMatrix(matrixName, seed, halfRows); + } + + private void loadAndRunTest(String[] expectedHeavyHitters, String testName){ + + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + Types.ExecMode platformOld = rtplatform; + rtplatform = Types.ExecMode.SINGLE_NODE; + + Thread t1 = null, t2 = null; + + try { + getAndLoadTestConfiguration(testName); + String HOME = SCRIPT_DIR + TEST_DIR; + + writeInputMatrices(); + + int port1 = getRandomAvailablePort(); + int port2 = getRandomAvailablePort(); + t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); + t2 = startLocalFedWorkerThread(port2); + + // Run actual dml script with federated matrix + fullDMLScriptName = HOME + testName + ".dml"; + programArgs = new String[] { "-stats", "-nvargs", + "X1=" + TestUtils.federatedAddress(port1, input("X1")), + "X2=" + TestUtils.federatedAddress(port2, input("X2")), + "Y=" + input("Y"), "r=" + rows, "c=" + cols, "Z=" + output("Z")}; + runTest(true, false, null, -1); + + // Run reference dml script with normal matrix + fullDMLScriptName = HOME + testName + "Reference.dml"; + programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), + "Y=" + input("Y"), "Z=" + expected("Z")}; + runTest(true, false, null, -1); + + // compare via files + compareResults(1e-9); + if (!heavyHittersContainsAllString(expectedHeavyHitters)) + fail("The following expected heavy hitters are missing: " + + Arrays.toString(missingHeavyHitters(expectedHeavyHitters))); + } + finally { + TestUtils.shutdownThreads(t1, t2); + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java new file mode 100644 index 00000000000..3e8f8719a65 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.federated.fedplanning; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Ignore; +import org.junit.Test; + +import java.io.File; +import java.util.Arrays; + +import static org.junit.Assert.fail; + +@net.jcip.annotations.NotThreadSafe +public class FederatedL2SVMPlanningTest extends AutomatedTestBase { + private static final Log LOG = LogFactory.getLog(FederatedL2SVMPlanningTest.class.getName()); + + private final static String TEST_DIR = "functions/privacy/fedplanning/"; + private final static String TEST_NAME = "FederatedL2SVMPlanningTest"; + private final static String TEST_NAME_2 = "FederatedL2SVMFunctionPlanningTest"; + private final static String TEST_CLASS_DIR = TEST_DIR + FederatedL2SVMPlanningTest.class.getSimpleName() + "/"; + private static File TEST_CONF_FILE; + + private final static int blocksize = 1024; + public final int rows = 1000; + public final int cols = 100; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"})); + addTestConfiguration(TEST_NAME_2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_2, new String[] {"Z"})); + } + + @Test + public void runL2SVMFOUTTest(){ + String[] expectedHeavyHitters = new String[]{ "fed_fedinit", "fed_ba+*", "fed_tak+*", "fed_+*", + "fed_max", "fed_1-*", "fed_tsmm", "fed_>"}; + setTestConf("SystemDS-config-fout.xml"); + loadAndRunTest(expectedHeavyHitters, TEST_NAME); + } + + @Test + public void runL2SVMHeuristicTest(){ + String[] expectedHeavyHitters = new String[]{ "fed_fedinit", "fed_ba+*"}; + setTestConf("SystemDS-config-heuristic.xml"); + loadAndRunTest(expectedHeavyHitters, TEST_NAME); + } + + @Test + @Ignore //TODO + public void runL2SVMFunctionFOUTTest(){ + String[] expectedHeavyHitters = new String[]{ "fed_fedinit", "fed_ba+*", "fed_tak+*", "fed_+*", + "fed_max", "fed_1-*", "fed_tsmm", "fed_>"}; + setTestConf("SystemDS-config-fout.xml"); + loadAndRunTest(expectedHeavyHitters, TEST_NAME_2); + } + + @Test + @Ignore //TODO + public void runL2SVMFunctionHeuristicTest(){ + String[] expectedHeavyHitters = new String[]{ "fed_fedinit", "fed_ba+*"}; + setTestConf("SystemDS-config-heuristic.xml"); + loadAndRunTest(expectedHeavyHitters, TEST_NAME_2); + } + + private void setTestConf(String test_conf){ + TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, test_conf); + } + + private void writeInputMatrices(){ + writeStandardRowFedMatrix("X1", 65); + writeStandardRowFedMatrix("X2", 75); + writeBinaryVector("Y", 44); + } + + private void writeBinaryVector(String matrixName, long seed){ + double[][] matrix = getRandomMatrix(rows, 1, -1, 1, 1, seed); + for(int i = 0; i < rows; i++) + matrix[i][0] = (matrix[i][0] > 0) ? 1 : -1; + MatrixCharacteristics mc = new MatrixCharacteristics(rows, 1, blocksize, rows); + writeInputMatrixWithMTD(matrixName, matrix, false, mc); + } + + @SuppressWarnings("unused") + private void writeStandardMatrix(String matrixName, long seed){ + writeStandardMatrix(matrixName, seed, rows); + } + + private void writeStandardMatrix(String matrixName, long seed, int numRows){ + double[][] matrix = getRandomMatrix(numRows, cols, 0, 1, 1, seed); + writeStandardMatrix(matrixName, numRows, matrix); + } + + private void writeStandardMatrix(String matrixName, int numRows, double[][] matrix){ + MatrixCharacteristics mc = new MatrixCharacteristics(numRows, cols, blocksize, (long) numRows * cols); + writeInputMatrixWithMTD(matrixName, matrix, false, mc); + } + + private void writeStandardRowFedMatrix(String matrixName, long seed){ + int halfRows = rows/2; + writeStandardMatrix(matrixName, seed, halfRows); + } + + private void loadAndRunTest(String[] expectedHeavyHitters, String testName){ + + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + Types.ExecMode platformOld = rtplatform; + rtplatform = Types.ExecMode.SINGLE_NODE; + + Thread t1 = null, t2 = null; + + try { + getAndLoadTestConfiguration(testName); + String HOME = SCRIPT_DIR + TEST_DIR; + + writeInputMatrices(); + + int port1 = getRandomAvailablePort(); + int port2 = getRandomAvailablePort(); + t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); + t2 = startLocalFedWorkerThread(port2); + + // Run actual dml script with federated matrix + fullDMLScriptName = HOME + testName + ".dml"; + programArgs = new String[] { "-stats", "-nvargs", + "X1=" + TestUtils.federatedAddress(port1, input("X1")), + "X2=" + TestUtils.federatedAddress(port2, input("X2")), + "Y=" + input("Y"), "r=" + rows, "c=" + cols, "Z=" + output("Z")}; + runTest(true, false, null, -1); + + // Run reference dml script with normal matrix + fullDMLScriptName = HOME + testName + "Reference.dml"; + programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), + "Y=" + input("Y"), "Z=" + expected("Z")}; + runTest(true, false, null, -1); + + // compare via files + compareResults(1e-9); + if (!heavyHittersContainsAllString(expectedHeavyHitters)) + fail("The following expected heavy hitters are missing: " + + Arrays.toString(missingHeavyHitters(expectedHeavyHitters))); + } + finally { + TestUtils.shutdownThreads(t1, t2); + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } + + /** + * Override default configuration with custom test configuration to ensure + * scratch space and local temporary directory locations are also updated. + */ + @Override + protected File getConfigTemplateFile() { + // Instrumentation in this test's output log to show custom configuration file used for template. + LOG.info("This test case overrides default configuration with " + TEST_CONF_FILE.getPath()); + return TEST_CONF_FILE; + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedMultiplyPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedMultiplyPlanningTest.java new file mode 100644 index 00000000000..5b54f14d059 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedMultiplyPlanningTest.java @@ -0,0 +1,318 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.federated.fedplanning; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; + +import java.io.File; +import java.util.Arrays; +import java.util.Collection; + +import static org.junit.Assert.fail; + +@RunWith(value = Parameterized.class) +@net.jcip.annotations.NotThreadSafe +public class FederatedMultiplyPlanningTest extends AutomatedTestBase { + private static final Log LOG = LogFactory.getLog(FederatedMultiplyPlanningTest.class.getName()); + + private final static String TEST_DIR = "functions/privacy/fedplanning/"; + private final static String TEST_NAME = "FederatedMultiplyPlanningTest"; + private final static String TEST_NAME_2 = "FederatedMultiplyPlanningTest2"; + private final static String TEST_NAME_3 = "FederatedMultiplyPlanningTest3"; + private final static String TEST_NAME_4 = "FederatedMultiplyPlanningTest4"; + private final static String TEST_NAME_5 = "FederatedMultiplyPlanningTest5"; + private final static String TEST_NAME_6 = "FederatedMultiplyPlanningTest6"; + private final static String TEST_NAME_7 = "FederatedMultiplyPlanningTest7"; + private final static String TEST_NAME_8 = "FederatedMultiplyPlanningTest8"; + private final static String TEST_NAME_9 = "FederatedMultiplyPlanningTest9"; + private final static String TEST_NAME_10 = "FederatedMultiplyPlanningTest10"; + private final static String TEST_NAME_11 = "FederatedMultiplyPlanningTest11"; + private final static String TEST_NAME_12 = "FederatedMultiplyPlanningTest12"; + private final static String TEST_CLASS_DIR = TEST_DIR + FederatedMultiplyPlanningTest.class.getSimpleName() + "/"; + private static File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, "SystemDS-config-heuristic.xml"); + + private final static int blocksize = 1024; + @Parameterized.Parameter() + public int rows; + @Parameterized.Parameter(1) + public int cols; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"})); + addTestConfiguration(TEST_NAME_2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_2, new String[] {"Z"})); + addTestConfiguration(TEST_NAME_3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_3, new String[] {"Z.scalar"})); + addTestConfiguration(TEST_NAME_4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_4, new String[] {"Z"})); + addTestConfiguration(TEST_NAME_5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_5, new String[] {"Z"})); + addTestConfiguration(TEST_NAME_6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_6, new String[] {"Z"})); + addTestConfiguration(TEST_NAME_7, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_7, new String[] {"Z"})); + addTestConfiguration(TEST_NAME_8, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_8, new String[] {"Z.scalar"})); + addTestConfiguration(TEST_NAME_9, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_9, new String[] {"Z.scalar"})); + addTestConfiguration(TEST_NAME_10, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_10, new String[] {"Z"})); + addTestConfiguration(TEST_NAME_11, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_11, new String[] {"Z"})); + addTestConfiguration(TEST_NAME_12, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_12, new String[] {"Z"})); + } + + @Parameterized.Parameters + public static Collection data() { + // rows have to be even and > 1 + return Arrays.asList(new Object[][] { + {100, 10} + }); + } + + @Test + public void federatedMultiplyCP() { + String[] expectedHeavyHitters = new String[]{"fed_*", "fed_fedinit", "fed_r'", "fed_ba+*"}; + federatedTwoMatricesSingleNodeTest(TEST_NAME, expectedHeavyHitters); + } + + @Test + public void federatedRowSum(){ + String[] expectedHeavyHitters = new String[]{"fed_*", "fed_r'", "fed_fedinit", "fed_ba+*", "fed_uark+"}; + federatedTwoMatricesSingleNodeTest(TEST_NAME_2, expectedHeavyHitters); + } + + @Test + public void federatedTernarySequence(){ + String[] expectedHeavyHitters = new String[]{"fed_+*", "fed_1-*", "fed_fedinit", "fed_uak+"}; + federatedTwoMatricesSingleNodeTest(TEST_NAME_3, expectedHeavyHitters); + } + + @Test + public void federatedAggregateBinarySequence(){ + cols = rows; + String[] expectedHeavyHitters = new String[]{"fed_ba+*", "fed_*", "fed_fedinit"}; + federatedTwoMatricesSingleNodeTest(TEST_NAME_4, expectedHeavyHitters); + } + + @Test + public void federatedAggregateBinaryColFedSequence(){ + cols = rows; + //TODO: When alignment checks have been added to getFederatedOut in AFederatedPlanner, + // the following expectedHeavyHitters can be added. Until then, fed_* will not be generated. + //String[] expectedHeavyHitters = new String[]{"fed_ba+*","fed_*","fed_fedinit"}; + String[] expectedHeavyHitters = new String[]{"fed_ba+*","fed_fedinit"}; + federatedTwoMatricesSingleNodeTest(TEST_NAME_5, expectedHeavyHitters); + } + + @Test + public void federatedAggregateBinarySequence2(){ + String[] expectedHeavyHitters = new String[]{"fed_ba+*","fed_fedinit"}; + federatedTwoMatricesSingleNodeTest(TEST_NAME_6, expectedHeavyHitters); + } + + @Test + public void federatedMultiplyDoubleHop() { + String[] expectedHeavyHitters = new String[]{"fed_*", "fed_fedinit", "fed_ba+*"}; //TODO "fed_r' " ? + federatedTwoMatricesSingleNodeTest(TEST_NAME_7, expectedHeavyHitters); + } + + @Test + public void federatedMultiplyDoubleHop2() { + String[] expectedHeavyHitters = new String[]{"fed_fedinit", "fed_ba+*"}; + federatedTwoMatricesSingleNodeTest(TEST_NAME_8, expectedHeavyHitters); + } + + @Test + public void federatedMultiplyPlanningTest9(){ + String[] expectedHeavyHitters = new String[]{"fed_+*", "fed_1-*", "fed_fedinit", "fed_max"}; //TODO "fed_tak+*" + federatedTwoMatricesSingleNodeTest(TEST_NAME_9, expectedHeavyHitters); + } + + @Test + public void federatedMultiplyPlanningTest10(){ + String[] expectedHeavyHitters = new String[]{"fed_fedinit", "fed_^2"}; + TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, "SystemDS-config-fout.xml"); + federatedTwoMatricesSingleNodeTest(TEST_NAME_10, expectedHeavyHitters); + } + + @Test + public void federatedMultiplyPlanningTest11(){ + String[] expectedHeavyHitters = new String[]{"fed_fedinit"}; + federatedTwoMatricesSingleNodeTest(TEST_NAME_11, expectedHeavyHitters); + } + + @Test + public void federatedMultiplyPlanningTest12(){ + String[] expectedHeavyHitters = new String[]{"fed_fedinit"}; + rows = 30; + cols = 30; + federatedTwoMatricesSingleNodeTest(TEST_NAME_12, expectedHeavyHitters); + } + + private void writeStandardMatrix(String matrixName, long seed){ + int halfRows = rows/2; + double[][] matrix = getRandomMatrix(halfRows, cols, 0, 1, 1, seed); + MatrixCharacteristics mc = new MatrixCharacteristics(halfRows, cols, blocksize, (long) halfRows * cols); + writeInputMatrixWithMTD(matrixName, matrix, false, mc); + } + + private void writeColStandardMatrix(String matrixName, long seed){ + int halfCols = cols/2; + double[][] matrix = getRandomMatrix(rows, halfCols, 0, 1, 1, seed); + MatrixCharacteristics mc = new MatrixCharacteristics(rows, halfCols, blocksize, (long) halfCols *rows); + writeInputMatrixWithMTD(matrixName, matrix, false, mc); + } + + private void writeRowFederatedVector(String matrixName, long seed){ + int halfCols = cols / 2; + double[][] matrix = getRandomMatrix(halfCols, 1, 0, 1, 1, seed); + MatrixCharacteristics mc = new MatrixCharacteristics(halfCols, 1, blocksize, (long) halfCols *rows); + writeInputMatrixWithMTD(matrixName, matrix, false, mc); + } + + private void writeInputMatrices(String testName){ + if ( testName.equals(TEST_NAME_5) ){ + writeColStandardMatrix("X1", 42); + writeColStandardMatrix("X2", 1340); + writeColStandardMatrix("Y1", 44); + writeColStandardMatrix("Y2", 21); + } + else if ( testName.equals(TEST_NAME_6) ){ + writeColStandardMatrix("X1", 42); + writeColStandardMatrix("X2", 1340); + writeRowFederatedVector("Y1", 44); + writeRowFederatedVector("Y2", 21); + } + else if ( testName.equals(TEST_NAME_8) ){ + writeColStandardMatrix("X1", 42); + writeColStandardMatrix("X2", 1340); + writeColStandardMatrix("Y1", 44); + writeColStandardMatrix("Y2", 21); + writeColStandardMatrix("W1", 76); + writeColStandardMatrix("W2", 11); + } + else if ( testName.equals(TEST_NAME_10) || testName.equals(TEST_NAME_12) ){ + writeStandardMatrix("X1", 42); + writeStandardMatrix("X2", 1340); + } + else { + writeStandardMatrix("X1", 42); + writeStandardMatrix("X2", 1340); + if ( testName.equals(TEST_NAME_4) ){ + writeStandardMatrix("Y1", 44); + writeStandardMatrix("Y2", 21); + } + else { + writeStandardMatrix("Y1", 44); + writeStandardMatrix("Y2", 21); + } + } + } + + private void federatedTwoMatricesSingleNodeTest(String testName, String[] expectedHeavyHitters){ + federatedTwoMatricesTest(Types.ExecMode.SINGLE_NODE, testName, expectedHeavyHitters); + } + + private void federatedTwoMatricesTest(Types.ExecMode execMode, String testName, String[] expectedHeavyHitters) { + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + Types.ExecMode platformOld = rtplatform; + rtplatform = execMode; + if(rtplatform == Types.ExecMode.SPARK) { + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + } + Thread t1 = null, t2 = null; + + try{ + getAndLoadTestConfiguration(testName); + String HOME = SCRIPT_DIR + TEST_DIR; + + writeInputMatrices(testName); + + int port1 = getRandomAvailablePort(); + int port2 = getRandomAvailablePort(); + t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); + t2 = startLocalFedWorkerThread(port2); + + // Run actual dml script with federated matrix + fullDMLScriptName = HOME + testName + ".dml"; + programArgs = new String[] {"-stats", "-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")), + "X2=" + TestUtils.federatedAddress(port2, input("X2")), + "Y1=" + TestUtils.federatedAddress(port1, input("Y1")), + "Y2=" + TestUtils.federatedAddress(port2, input("Y2")), "r=" + rows, "c=" + cols, "Z=" + output("Z")}; + rewriteRealProgramArgs(testName, port1, port2); + runTest(true, false, null, -1); + + // Run reference dml script with normal matrix + fullDMLScriptName = HOME + testName + "Reference.dml"; + programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"), + "Y2=" + input("Y2"), "Z=" + expected("Z")}; + rewriteReferenceProgramArgs(testName); + runTest(true, false, null, -1); + + // compare via files + compareResults(1e-9); + if (!heavyHittersContainsAllString(expectedHeavyHitters)) + fail("The following expected heavy hitters are missing: " + + Arrays.toString(missingHeavyHitters(expectedHeavyHitters))); + } finally { + TestUtils.shutdownThreads(t1, t2); + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } + + private void rewriteRealProgramArgs(String testName, int port1, int port2){ + if ( testName.equals(TEST_NAME_4) || testName.equals(TEST_NAME_5) ){ + programArgs = new String[] {"-stats","-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")), + "X2=" + TestUtils.federatedAddress(port2, input("X2")), + "Y1=" + input("Y1"), + "Y2=" + input("Y2"), "r=" + rows, "c=" + cols, "Z=" + output("Z")}; + } else if ( testName.equals(TEST_NAME_8) ){ + programArgs = new String[] {"-stats","-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")), + "X2=" + TestUtils.federatedAddress(port2, input("X2")), + "Y1=" + TestUtils.federatedAddress(port1, input("Y1")), + "Y2=" + TestUtils.federatedAddress(port2, input("Y2")), + "W1=" + input("W1"), + "W2=" + input("W2"), + "r=" + rows, "c=" + cols, "Z=" + output("Z")}; + } + } + + private void rewriteReferenceProgramArgs(String testName){ + if ( testName.equals(TEST_NAME_8) ){ + programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"), + "Y2=" + input("Y2"), "W1=" + input("W1"), "W2=" + input("W2"), "Z=" + expected("Z")}; + } + } + + /** + * Override default configuration with custom test configuration to ensure + * scratch space and local temporary directory locations are also updated. + */ + @Override + protected File getConfigTemplateFile() { + // Instrumentation in this test's output log to show custom configuration file used for template. + LOG.info("This test case overrides default configuration with " + TEST_CONF_FILE.getPath()); + return TEST_CONF_FILE; + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/fedplanning/FTypeCombTest.java b/src/test/java/org/apache/sysds/test/functions/fedplanning/FTypeCombTest.java new file mode 100644 index 00000000000..e36d517d989 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/fedplanning/FTypeCombTest.java @@ -0,0 +1,71 @@ +package org.apache.sysds.test.functions.fedplanning; +///* +// * Licensed to the Apache Software Foundation (ASF) under one +// * or more contributor license agreements. See the NOTICE file +// * distributed with this work for additional information +// * regarding copyright ownership. The ASF licenses this file +// * to you under the Apache License, Version 2.0 (the +// * "License"); you may not use this file except in compliance +// * with the License. You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, +// * software distributed under the License is distributed on an +// * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// * KIND, either express or implied. See the License for the +// * specific language governing permissions and limitations +// * under the License. +// */ +// +//package org.apache.sysds.test.functions.privacy.fedplanning; +// +//import org.apache.sysds.hops.fedplanner.FTypes.FType; +//import org.apache.sysds.hops.fedplanner.FederatedPlannerCostbased; +//import org.apache.sysds.test.AutomatedTestBase; +//import org.junit.Assert; +//import org.junit.Test; +// +//import java.util.ArrayList; +//import java.util.List; +// +//public class FTypeCombTest extends AutomatedTestBase { +// +// @Override public void setUp() {} +// +// @Test +// public void ftypeCombTest(){ +// List secondInput = new ArrayList<>(); +// secondInput.add(null); +// List> inputFTypes = List.of( +// List.of(FType.ROW,FType.COL), +// secondInput, +// List.of(FType.BROADCAST,FType.FULL) +// ); +// +// FederatedPlannerCostbased planner = new FederatedPlannerCostbased(); +// List> actualCombinations = planner.getAllCombinations(inputFTypes); +// +// List expected1 = new ArrayList<>(); +// expected1.add(FType.ROW); +// expected1.add(null); +// expected1.add(FType.BROADCAST); +// List expected2 = new ArrayList<>(); +// expected2.add(FType.ROW); +// expected2.add(null); +// expected2.add(FType.FULL); +// List expected3 = new ArrayList<>(); +// expected3.add(FType.COL); +// expected3.add(null); +// expected3.add(FType.BROADCAST); +// List expected4 = new ArrayList<>(); +// expected4.add(FType.COL); +// expected4.add(null); +// expected4.add(FType.FULL); +// List> expectedCombinations = List.of(expected1,expected2, expected3, expected4); +// +// Assert.assertEquals(expectedCombinations.size(), actualCombinations.size()); +// for (List expectedComb : expectedCombinations) +// Assert.assertTrue(actualCombinations.contains(expectedComb)); +// } +//} diff --git a/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedCostEstimatorTest.java b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedCostEstimatorTest.java new file mode 100644 index 00000000000..073c8f1d9d5 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedCostEstimatorTest.java @@ -0,0 +1,373 @@ +package org.apache.sysds.test.functions.fedplanning; +///* +// * Licensed to the Apache Software Foundation (ASF) under one +// * or more contributor license agreements. See the NOTICE file +// * distributed with this work for additional information +// * regarding copyright ownership. The ASF licenses this file +// * to you under the Apache License, Version 2.0 (the +// * "License"); you may not use this file except in compliance +// * with the License. You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, +// * software distributed under the License is distributed on an +// * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// * KIND, either express or implied. See the License for the +// * specific language governing permissions and limitations +// * under the License. +// */ +// +//package org.apache.sysds.test.functions.privacy.fedplanning; +// +//import net.jcip.annotations.NotThreadSafe; +//import org.apache.sysds.api.DMLScript; +//import org.apache.sysds.common.Types; +//import org.apache.sysds.conf.ConfigurationManager; +//import org.apache.sysds.conf.DMLConfig; +//import org.apache.sysds.hops.AggBinaryOp; +//import org.apache.sysds.hops.BinaryOp; +//import org.apache.sysds.hops.DataOp; +//import org.apache.sysds.hops.Hop; +//import org.apache.sysds.hops.LiteralOp; +//import org.apache.sysds.hops.NaryOp; +//import org.apache.sysds.hops.ReorgOp; +//import org.apache.sysds.hops.cost.FederatedCost; +//import org.apache.sysds.hops.cost.FederatedCostEstimator; +//import org.apache.sysds.hops.fedplanner.FederatedPlannerCostbased; +//import org.apache.sysds.hops.ipa.FunctionCallGraph; +//import org.apache.sysds.parser.DMLProgram; +//import org.apache.sysds.parser.DMLTranslator; +//import org.apache.sysds.parser.LanguageException; +//import org.apache.sysds.parser.ParserFactory; +//import org.apache.sysds.parser.ParserWrapper; +//import org.apache.sysds.parser.StatementBlock; +//import org.apache.sysds.runtime.instructions.fed.FEDInstruction; +//import org.apache.sysds.test.AutomatedTestBase; +//import org.apache.sysds.test.TestConfiguration; +//import org.junit.After; +//import org.junit.Assert; +//import org.junit.Before; +//import org.junit.BeforeClass; +//import org.junit.Test; +// +//import java.io.FileNotFoundException; +//import java.io.IOException; +//import java.util.HashMap; +//import java.util.HashSet; +//import java.util.Set; +// +//import static org.apache.sysds.common.Types.OpOp2.MULT; +// +//@NotThreadSafe +//public class FederatedCostEstimatorTest extends AutomatedTestBase { +// +// private static final String TEST_DIR = "functions/privacy/fedplanning/"; +// private static final String HOME = SCRIPT_DIR + TEST_DIR; +// private static final String TEST_CLASS_DIR = TEST_DIR + FederatedCostEstimatorTest.class.getSimpleName() + "/"; +// FederatedCostEstimator fedCostEstimator = new FederatedCostEstimator(); +// +// private static double COMPUTE_FLOPS; +// private static double READ_PS; +// private static double NETWORK_PS; +// +// @Override +// public void setUp() {} +// +// @BeforeClass +// public static void storeConstants(){ +// COMPUTE_FLOPS = FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS; +// READ_PS = FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS; +// NETWORK_PS = FederatedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS; +// } +// +// @Before +// public void setConstants(){ +// FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2; +// FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10; +// FederatedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS = 5; +// } +// +// @After +// public void resetConstants(){ +// FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = COMPUTE_FLOPS; +// FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = READ_PS; +// FederatedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS = NETWORK_PS; +// } +// +// @Test +// public void simpleBinary() { +// +// /* +// * HOP Occurences ComputeCost ReadCost ComputeCostFinal ReadCostFinal +// * ------------------------------------------------------------------------------------------ +// * LiteralOp 16 1 0 0.0625 0 +// * DataGenOp 2 100 64 6.25 6.4 +// * BinaryOp 1 100 1600 6.25 160 +// * TOSTRING 1 1 800 0.0625 80 +// * UnaryOp 1 1 8 0.0625 0.8 +// */ +// double computeCost = (16+2*100+100+1+1) / (FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM); +// double readCost = (2*64+1600+800+8) / (FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS); +// +// double expectedCost = computeCost + readCost; +// runTest("BinaryCostEstimatorTest.dml", false, expectedCost); +// } +// +// @Test +// public void simpleBinaryHopRelTest() { +// runHopRelTest("BinaryCostEstimatorTest.dml", false); +// } +// +// @Test +// public void ifElseTest(){ +// double computeCost = (16+2*100+100+1+1) / (FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM); +// double readCost = (2*64+1600+800+8) / (FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS); +// double expectedCost = ((computeCost + readCost + 0.8 + 0.0625 + 0.0625) / 2) + 0.0625 + 0.8 + 0.0625; +// runTest("IfElseCostEstimatorTest.dml", false, expectedCost); +// } +// +// @Test +// public void ifElseHopRelTest(){ +// runHopRelTest("IfElseCostEstimatorTest.dml", false); +// } +// +// @Test +// public void whileTest(){ +// double computeCost = (16+2*100+100+1+1) / (FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM); +// double readCost = (2*64+1600+800+8) / (FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS); +// double expectedCost = (computeCost + readCost + 0.0625 + 0.0625 + 0.8) * StatementBlock.DEFAULT_LOOP_REPETITIONS; +// runTest("WhileCostEstimatorTest.dml", false, expectedCost); +// } +// +// @Test +// public void whileHopRelTest(){ +// runHopRelTest("WhileCostEstimatorTest.dml", false); +// } +// +// @Test +// public void forLoopTest(){ +// double computeCost = (16+2*100+100+1+1) / (FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM); +// double readCost = (2*64+1600+800+8) / (FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS); +// double predicateCost = 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625; +// double expectedCost = (computeCost + readCost + predicateCost) * 5; +// runTest("ForLoopCostEstimatorTest.dml", false, expectedCost); +// } +// +// @Test +// public void forLoopHopRelTest(){ +// runHopRelTest("ForLoopCostEstimatorTest.dml", false); +// } +// +// @Test +// public void parForLoopTest(){ +// double computeCost = (16+2*100+100+1+1) / (FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM); +// double readCost = (2*64+1600+800+8) / (FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS); +// double predicateCost = 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 0.0625; +// double expectedCost = (computeCost + readCost + predicateCost) * 5; +// runTest("ParForLoopCostEstimatorTest.dml", false, expectedCost); +// } +// +// @Test +// public void parForLoopHopRelTest(){ +// runHopRelTest("ParForLoopCostEstimatorTest.dml", false); +// } +// +// @Test +// public void functionTest(){ +// double computeCost = (16+2*100+100+1+1) / (FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM); +// double readCost = (2*64+1600+800+8) / (FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS); +// double expectedCost = (computeCost + readCost); +// runTest("FunctionCostEstimatorTest.dml", false, expectedCost); +// } +// +// @Test +// public void functionHopRelTest(){ +// runHopRelTest("FunctionCostEstimatorTest.dml", false); +// } +// +// @Test +// public void federatedMultiply() { +// +// double literalOpCost = 10*0.0625; +// double naryOpCostSpecial = (0.125+2.2); +// double naryOpCostSpecial2 = (0.25+6.4); +// double naryOpCost = 4*(0.125+1.6); +// double reorgOpCost = 6250+80015.2+160030.4; +// double binaryOpMultCost = 3125+160000; +// double aggBinaryOpCost = 125000+160015.2+160030.4+190.4; +// double dataOpCost = 2*(6250+5.6); +// double dataOpWriteCost = 6.25+100.3; +// +// double expectedCost = literalOpCost + naryOpCost + naryOpCostSpecial + naryOpCostSpecial2 + reorgOpCost +// + binaryOpMultCost + aggBinaryOpCost + dataOpCost + dataOpWriteCost; +// runTest("FederatedMultiplyCostEstimatorTest.dml", false, expectedCost); +// +// double aggBinaryActualCost = hops.stream() +// .filter(hop -> hop instanceof AggBinaryOp) +// .mapToDouble(aggHop -> aggHop.getFederatedCost().getTotal()-aggHop.getFederatedCost().getInputTotalCost()) +// .sum(); +// Assert.assertEquals(aggBinaryOpCost, aggBinaryActualCost, 0.0001); +// +// double writeActualCost = hops.stream() +// .filter(hop -> hop instanceof DataOp) +// .mapToDouble(writeHop -> writeHop.getFederatedCost().getTotal()-writeHop.getFederatedCost().getInputTotalCost()) +// .sum(); +// Assert.assertEquals(dataOpWriteCost+dataOpCost, writeActualCost, 0.0001); +// } +// +// Set hops = new HashSet<>(); +// +// /** +// * Recursively adds the hop and its inputs to the set of hops. +// * @param hop root to be added to set of hops +// */ +// private void addHop(Hop hop){ +// hops.add(hop); +// for(Hop inHop : hop.getInput()){ +// addHop(inHop); +// } +// } +// +// /** +// * Sets dimensions of federated X and Y and sets binary multiplication to FOUT. +// * @param prog dml program where the HOPS are modified +// */ +// private void modifyFedouts(DMLProgram prog){ +// prog.getStatementBlocks().forEach(stmBlock -> stmBlock.getHops().forEach(this::addHop)); +// hops.forEach(hop -> { +// if ( hop instanceof DataOp || (hop instanceof BinaryOp && ((BinaryOp) hop).getOp() == MULT ) ){ +// hop.setFederatedOutput(FEDInstruction.FederatedOutput.FOUT); +// hop.setExecType(Types.ExecType.FED); +// } else { +// hop.setFederatedOutput(FEDInstruction.FederatedOutput.LOUT); +// } +// if ( hop.getOpString().equals("Fed Y") || hop.getOpString().equals("Fed X") ){ +// hop.setDim1(10000); +// hop.setDim2(10); +// } +// }); +// } +// +// @SuppressWarnings("unused") +// private void printHopsInfo(){ +// //LiteralOp +// long literalCount = hops.stream().filter(hop -> hop instanceof LiteralOp).count(); +// System.out.println("LiteralOp Count: " + literalCount); +// //NaryOp +// long naryCount = hops.stream().filter(hop -> hop instanceof NaryOp).count(); +// System.out.println("NaryOp Count " + naryCount); +// //ReorgOp +// long reorgCount = hops.stream().filter(hop -> hop instanceof ReorgOp).count(); +// System.out.println("ReorgOp Count: " + reorgCount); +// //BinaryOp +// long binaryCount = hops.stream().filter(hop -> hop instanceof BinaryOp).count(); +// System.out.println("Binary count: " + binaryCount); +// //AggBinaryOp +// long aggBinaryCount = hops.stream().filter(hop -> hop instanceof AggBinaryOp).count(); +// System.out.println("AggBinaryOp Count: " + aggBinaryCount); +// //DataOp +// long dataOpCount = hops.stream().filter(hop -> hop instanceof DataOp).count(); +// System.out.println("DataOp Count: " + dataOpCount); +// +// hops.stream().map(Hop::getClass).distinct().forEach(System.out::println); +// } +// +// private DMLProgram testSetup(String scriptFilename) throws IOException{ +// setTestConfig(scriptFilename); +// String dmlScriptString = readScript(scriptFilename); +// +// //parsing, dependency analysis and constructing hops (step 3 and 4 in DMLScript.java) +// ParserWrapper parser = ParserFactory.createParser(); +// DMLProgram prog = parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, new HashMap<>()); +// DMLTranslator dmlt = new DMLTranslator(prog); +// dmlt.liveVariableAnalysis(prog); +// dmlt.validateParseTree(prog); +// dmlt.constructHops(prog); +// if ( scriptFilename.equals("FederatedMultiplyCostEstimatorTest.dml")){ +// modifyFedouts(prog); +// dmlt.rewriteHopsDAG(prog); +// hops = new HashSet<>(); +// prog.getStatementBlocks().forEach(stmBlock -> stmBlock.getHops().forEach(this::addHop)); +// } +// return prog; +// } +// +// private void compareResults(DMLProgram prog) { +// FederatedPlannerCostbased rewriter = new FederatedPlannerCostbased(); +// rewriter.rewriteProgram(prog, new FunctionCallGraph(prog), null); +// +// double actualCost = 0; +// for ( Hop root : rewriter.getTerminalHops() ){ +// actualCost += root.getFederatedCost().getTotal(); +// } +// +// +// rewriter.getTerminalHops().forEach(Hop::resetFederatedCost); +// fedCostEstimator = new FederatedCostEstimator(); +// double expectedCost = 0; +// for ( Hop root : rewriter.getTerminalHops() ) +// expectedCost += fedCostEstimator.costEstimate(root).getTotal(); +// Assert.assertEquals(expectedCost, actualCost, 0.0001); +// } +// +// private void runHopRelTest( String scriptFilename, boolean expectedException ) { +// boolean raisedException = false; +// try +// { +// DMLProgram prog = testSetup(scriptFilename); +// compareResults(prog); +// } +// catch(LanguageException ex) { +// raisedException = true; +// if(raisedException!=expectedException) +// ex.printStackTrace(); +// } +// catch(Exception ex2) { +// ex2.printStackTrace(); +// throw new RuntimeException(ex2); +// } +// +// Assert.assertEquals("Expected exception does not match raised exception", +// expectedException, raisedException); +// } +// +// private void runTest( String scriptFilename, boolean expectedException, double expectedCost ) { +// boolean raisedException = false; +// try +// { +// DMLProgram prog = testSetup(scriptFilename); +// +// fedCostEstimator = new FederatedCostEstimator(); +// FederatedCost actualCost = fedCostEstimator.costEstimate(prog); +// Assert.assertEquals(expectedCost, actualCost.getTotal(), 0.0001); +// } +// catch(LanguageException ex) { +// raisedException = true; +// if(raisedException!=expectedException) +// ex.printStackTrace(); +// } +// catch(Exception ex2) { +// ex2.printStackTrace(); +// throw new RuntimeException(ex2); +// } +// +// Assert.assertEquals("Expected exception does not match raised exception", +// expectedException, raisedException); +// } +// +// private void setTestConfig(String scriptFilename) throws FileNotFoundException { +// int index = scriptFilename.lastIndexOf(".dml"); +// String testName = scriptFilename.substring(0, index > 0 ? index : scriptFilename.length()); +// TestConfiguration testConfig = new TestConfiguration(TEST_CLASS_DIR, testName, new String[] {}); +// addTestConfiguration(testName, testConfig); +// loadTestConfiguration(testConfig); +// +// DMLConfig conf = new DMLConfig(getCurConfigFile().getPath()); +// ConfigurationManager.setLocalConfig(conf); +// } +// +// private static String readScript(String scriptFilename) throws IOException { +// return DMLScript.readDMLScript(true, HOME + scriptFilename); +// } +//} diff --git a/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedDynamicPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedDynamicPlanningTest.java new file mode 100644 index 00000000000..23da01d4386 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedDynamicPlanningTest.java @@ -0,0 +1,188 @@ +package org.apache.sysds.test.functions.fedplanning; +///* +// * Licensed to the Apache Software Foundation (ASF) under one +// * or more contributor license agreements. See the NOTICE file +// * distributed with this work for additional information +// * regarding copyright ownership. The ASF licenses this file +// * to you under the Apache License, Version 2.0 (the +// * "License"); you may not use this file except in compliance +// * with the License. You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, +// * software distributed under the License is distributed on an +// * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// * KIND, either express or implied. See the License for the +// * specific language governing permissions and limitations +// * under the License. +// */ +// +//package org.apache.sysds.test.functions.privacy.fedplanning; +// +//import org.apache.commons.logging.Log; +//import org.apache.commons.logging.LogFactory; +//import org.apache.sysds.api.DMLScript; +//import org.apache.sysds.common.Types; +//import org.apache.sysds.runtime.meta.MatrixCharacteristics; +//import org.apache.sysds.runtime.privacy.PrivacyConstraint; +//import org.apache.sysds.test.AutomatedTestBase; +//import org.apache.sysds.test.TestConfiguration; +//import org.apache.sysds.test.TestUtils; +//import org.junit.Test; +// +//import java.io.File; +//import java.util.Arrays; +// +//import static org.junit.Assert.fail; +// +//@net.jcip.annotations.NotThreadSafe +//public class FederatedDynamicPlanningTest extends AutomatedTestBase { +// private static final Log LOG = LogFactory.getLog(FederatedDynamicPlanningTest.class.getName()); +// +// private final static String TEST_DIR = "functions/privacy/fedplanning/"; +// private final static String TEST_NAME = "FederatedDynamicFunctionPlanningTest"; +// private final static String TEST_CLASS_DIR = TEST_DIR + FederatedDynamicPlanningTest.class.getSimpleName() + "/"; +// private static File TEST_CONF_FILE; +// +// private final static int blocksize = 1024; +// public final int rows = 1000; +// public final int cols = 1000; +// +// @Override +// public void setUp() { +// TestUtils.clearAssertionInformation(); +// addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"})); +// } +// +// @Test +// public void runDynamicFullFunctionTest() { +// // compared to `FederatedL2SVMPlanningTest` this does not create `fed_+*` or `fed_tsmm`, probably due to +// // some rewrites not being applied. Might be a bug. +// String[] expectedHeavyHitters = new String[] {"fed_fedinit", "fed_ba+*", "fed_tak+*", "fed_max", +// "fed_1-*", "fed_>"}; +// setTestConf("SystemDS-config-fout.xml"); +// loadAndRunTest(expectedHeavyHitters, TEST_NAME); +// } +// +// @Test +// public void runDynamicHeuristicFunctionTest() { +// // compared to `FederatedL2SVMPlanningTest` this does not create `fed_+*` or `fed_tsmm`, probably due to +// // some rewrites not being applied. Might be a bug. +// String[] expectedHeavyHitters = new String[] {"fed_fedinit", "fed_ba+*"}; +// setTestConf("SystemDS-config-heuristic.xml"); +// loadAndRunTest(expectedHeavyHitters, TEST_NAME); +// } +// +// @Test +// public void runDynamicCostBasedFunctionTest() { +// // compared to `FederatedL2SVMPlanningTest` this does not create `fed_+*` or `fed_tsmm`, probably due to +// // some rewrites not being applied. Might be a bug. +// String[] expectedHeavyHitters = new String[] {"fed_fedinit", "fed_ba+*", "fed_tak+*", "fed_max", +// "fed_1-*", "fed_>"}; +// setTestConf("SystemDS-config-cost-based.xml"); +// loadAndRunTest(expectedHeavyHitters, TEST_NAME); +// } +// +// private void setTestConf(String test_conf) { +// TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, test_conf); +// } +// +// private void writeInputMatrices() { +// writeBinaryVector("A", 42, rows, null); +// writeStandardMatrix("B1", 65, rows / 2, cols, null); +// writeStandardMatrix("B2", 75, rows / 2, cols, null); +// writeStandardMatrix("C1", 13, rows, cols / 2, null); +// writeStandardMatrix("C2", 17, rows, cols / 2, null); +// } +// +// private void writeBinaryVector(String matrixName, long seed, int numRows, PrivacyConstraint privacyConstraint){ +// double[][] matrix = getRandomMatrix(numRows, 1, -1, 1, 1, seed); +// for(int i = 0; i < numRows; i++) +// matrix[i][0] = (matrix[i][0] > 0) ? 1 : -1; +// MatrixCharacteristics mc = new MatrixCharacteristics(numRows, 1, blocksize, numRows); +// writeInputMatrixWithMTD(matrixName, matrix, false, mc, privacyConstraint); +// } +// +// private void writeStandardMatrix(String matrixName, long seed, int numRows, int numCols, +// PrivacyConstraint privacyConstraint) { +// double[][] matrix = getRandomMatrix(numRows, numCols, 0, 1, 1, seed); +// writeStandardMatrix(matrixName, numRows, numCols, privacyConstraint, matrix); +// } +// +// private void writeStandardMatrix(String matrixName, int numRows, int numCols, PrivacyConstraint privacyConstraint, +// double[][] matrix) { +// MatrixCharacteristics mc = new MatrixCharacteristics(numRows, numCols, blocksize, (long) numRows * numCols); +// writeInputMatrixWithMTD(matrixName, matrix, false, mc, privacyConstraint); +// } +// +// private void loadAndRunTest(String[] expectedHeavyHitters, String testName) { +// +// boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; +// Types.ExecMode platformOld = rtplatform; +// rtplatform = Types.ExecMode.SINGLE_NODE; +// +// Thread t1 = null, t2 = null; +// +// try { +// getAndLoadTestConfiguration(testName); +// String HOME = SCRIPT_DIR + TEST_DIR; +// +// writeInputMatrices(); +// +// int port1 = getRandomAvailablePort(); +// int port2 = getRandomAvailablePort(); +// t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); +// t2 = startLocalFedWorkerThread(port2); +// +// // Run actual dml script with federated matrix +// fullDMLScriptName = HOME + testName + ".dml"; +// programArgs = new String[] {"-stats", "-nvargs", +// "r=" + rows, "c=" + cols, +// "A=" + input("A"), +// "B1=" + TestUtils.federatedAddress(port1, input("B1")), +// "B2=" + TestUtils.federatedAddress(port2, input("B2")), +// "C1=" + TestUtils.federatedAddress(port1, input("C1")), +// "C2=" + TestUtils.federatedAddress(port2, input("C2")), +// "lB1=" + input("B1"), +// "lB2=" + input("B2"), +// "Z=" + output("Z")}; +// runTest(true, false, null, -1); +// +// // Run reference dml script with normal matrix +// fullDMLScriptName = HOME + testName + "Reference.dml"; +// programArgs = new String[] {"-nvargs", +// "r=" + rows, "c=" + cols, +// "A=" + input("A"), +// "B1=" + input("B1"), +// "B2=" + input("B2"), +// "C1=" + input("C1"), +// "C2=" + input("C2"), +// "Z=" + expected("Z")}; +// runTest(true, false, null, -1); +// +// // compare via files +// compareResults(1e-9); +// if(!heavyHittersContainsAllString(expectedHeavyHitters)) +// fail("The following expected heavy hitters are missing: " +// + Arrays.toString(missingHeavyHitters(expectedHeavyHitters))); +// } +// finally { +// TestUtils.shutdownThreads(t1, t2); +// rtplatform = platformOld; +// DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; +// } +// } +// +// /** +// * Override default configuration with custom test configuration to ensure scratch space and local temporary +// * directory locations are also updated. +// */ +// @Override +// protected File getConfigTemplateFile() { +// // Instrumentation in this test's output log to show custom configuration file used for template. +// LOG.info("This test case overrides default configuration with " + TEST_CONF_FILE.getPath()); +// return TEST_CONF_FILE; +// } +// +//} diff --git a/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedKMeansPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedKMeansPlanningTest.java new file mode 100644 index 00000000000..48d9a06b8c3 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedKMeansPlanningTest.java @@ -0,0 +1,168 @@ +package org.apache.sysds.test.functions.fedplanning; +///* +// * Licensed to the Apache Software Foundation (ASF) under one +// * or more contributor license agreements. See the NOTICE file +// * distributed with this work for additional information +// * regarding copyright ownership. The ASF licenses this file +// * to you under the Apache License, Version 2.0 (the +// * "License"); you may not use this file except in compliance +// * with the License. You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, +// * software distributed under the License is distributed on an +// * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// * KIND, either express or implied. See the License for the +// * specific language governing permissions and limitations +// * under the License. +// */ +// +//package org.apache.sysds.test.functions.privacy.fedplanning; +// +//import org.apache.commons.logging.Log; +//import org.apache.commons.logging.LogFactory; +//import org.apache.sysds.api.DMLScript; +//import org.apache.sysds.common.Types; +//import org.apache.sysds.runtime.meta.MatrixCharacteristics; +//import org.apache.sysds.runtime.privacy.PrivacyConstraint; +//import org.apache.sysds.test.AutomatedTestBase; +//import org.apache.sysds.test.TestConfiguration; +//import org.apache.sysds.test.TestUtils; +//import org.junit.Ignore; +//import org.junit.Test; +// +//import java.io.File; +//import java.util.Arrays; +// +//import static org.junit.Assert.fail; +// +//public class FederatedKMeansPlanningTest extends AutomatedTestBase { +// private static final Log LOG = LogFactory.getLog(FederatedKMeansPlanningTest.class.getName()); +// +// private final static String TEST_DIR = "functions/privacy/fedplanning/"; +// private final static String TEST_NAME = "FederatedKMeansPlanningTest"; +// private final static String TEST_CLASS_DIR = TEST_DIR + FederatedKMeansPlanningTest.class.getSimpleName() + "/"; +// private static File TEST_CONF_FILE; +// +// private final static int blocksize = 1024; +// public final int rows = 1000; +// public final int cols = 100; +// +// @Override +// public void setUp() { +// TestUtils.clearAssertionInformation(); +// addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"})); +// } +// +// @Test +// public void runKMeansFOUTTest(){ +// String[] expectedHeavyHitters = new String[]{}; +// setTestConf("SystemDS-config-fout.xml"); +// loadAndRunTest(expectedHeavyHitters, TEST_NAME); +// } +// +// @Test +// public void runKMeansHeuristicTest(){ +// String[] expectedHeavyHitters = new String[]{}; +// setTestConf("SystemDS-config-heuristic.xml"); +// loadAndRunTest(expectedHeavyHitters, TEST_NAME); +// } +// +// @Test +// public void runKMeansCostBasedTest(){ +// String[] expectedHeavyHitters = new String[]{ "fed_fedinit", "fed_ba+*", "fed_*", "fed_uack+", "fed_bcumoffk+"}; +// setTestConf("SystemDS-config-cost-based.xml"); +// loadAndRunTest(expectedHeavyHitters, TEST_NAME); +// } +// +// @Test +// public void runRuntimeTest(){ +// String[] expectedHeavyHitters = new String[]{}; +// TEST_CONF_FILE = new File("src/test/config/SystemDS-config.xml"); +// loadAndRunTest(expectedHeavyHitters, TEST_NAME); +// } +// +// private void setTestConf(String test_conf){ +// TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, test_conf); +// } +// +// /** +// * Override default configuration with custom test configuration to ensure +// * scratch space and local temporary directory locations are also updated. +// */ +// @Override +// protected File getConfigTemplateFile() { +// // Instrumentation in this test's output log to show custom configuration file used for template. +// LOG.info("This test case overrides default configuration with " + TEST_CONF_FILE.getPath()); +// return TEST_CONF_FILE; +// } +// +// private void writeInputMatrices(){ +// writeStandardRowFedMatrix("X1", 65, null); +// writeStandardRowFedMatrix("X2", 75, null); +// } +// +// private void writeStandardMatrix(String matrixName, long seed, int numRows, PrivacyConstraint privacyConstraint){ +// double[][] matrix = getRandomMatrix(numRows, cols, 0, 1, 1, seed); +// writeStandardMatrix(matrixName, numRows, privacyConstraint, matrix); +// } +// +// private void writeStandardMatrix(String matrixName, int numRows, PrivacyConstraint privacyConstraint, double[][] matrix){ +// MatrixCharacteristics mc = new MatrixCharacteristics(numRows, cols, blocksize, (long) numRows * cols); +// writeInputMatrixWithMTD(matrixName, matrix, false, mc, privacyConstraint); +// } +// +// private void writeStandardRowFedMatrix(String matrixName, long seed, PrivacyConstraint privacyConstraint){ +// int halfRows = rows/2; +// writeStandardMatrix(matrixName, seed, halfRows, privacyConstraint); +// } +// +// private void loadAndRunTest(String[] expectedHeavyHitters, String testName){ +// +// boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; +// Types.ExecMode platformOld = rtplatform; +// rtplatform = Types.ExecMode.SINGLE_NODE; +// +// Thread t1 = null, t2 = null; +// +// try { +// getAndLoadTestConfiguration(testName); +// String HOME = SCRIPT_DIR + TEST_DIR; +// +// writeInputMatrices(); +// +// int port1 = getRandomAvailablePort(); +// int port2 = getRandomAvailablePort(); +// t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); +// t2 = startLocalFedWorkerThread(port2); +// +// // Run actual dml script with federated matrix +// fullDMLScriptName = HOME + testName + ".dml"; +// programArgs = new String[] { "-stats", "-nvargs", +// "X1=" + TestUtils.federatedAddress(port1, input("X1")), +// "X2=" + TestUtils.federatedAddress(port2, input("X2")), +// "Y=" + input("Y"), "r=" + rows, "c=" + cols, "Z=" + output("Z")}; +// runTest(true, false, null, -1); +// +// // Run reference dml script with normal matrix +// fullDMLScriptName = HOME + testName + "Reference.dml"; +// programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), +// "Y=" + input("Y"), "Z=" + expected("Z")}; +// runTest(true, false, null, -1); +// +// // compare via files +// compareResults(1e-9); +// if (!heavyHittersContainsAllString(expectedHeavyHitters)) +// fail("The following expected heavy hitters are missing: " +// + Arrays.toString(missingHeavyHitters(expectedHeavyHitters))); +// } +// finally { +// TestUtils.shutdownThreads(t1, t2); +// rtplatform = platformOld; +// DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; +// } +// } +// +// +//} diff --git a/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedL2SVMPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedL2SVMPlanningTest.java new file mode 100644 index 00000000000..0ef4fde6a45 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedL2SVMPlanningTest.java @@ -0,0 +1,202 @@ +package org.apache.sysds.test.functions.fedplanning; +///* +// * Licensed to the Apache Software Foundation (ASF) under one +// * or more contributor license agreements. See the NOTICE file +// * distributed with this work for additional information +// * regarding copyright ownership. The ASF licenses this file +// * to you under the Apache License, Version 2.0 (the +// * "License"); you may not use this file except in compliance +// * with the License. You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, +// * software distributed under the License is distributed on an +// * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// * KIND, either express or implied. See the License for the +// * specific language governing permissions and limitations +// * under the License. +// */ +// +//package org.apache.sysds.test.functions.privacy.fedplanning; +// +//import org.apache.commons.logging.Log; +//import org.apache.commons.logging.LogFactory; +//import org.apache.sysds.api.DMLScript; +//import org.apache.sysds.common.Types; +//import org.apache.sysds.runtime.meta.MatrixCharacteristics; +//import org.apache.sysds.runtime.privacy.PrivacyConstraint; +//import org.apache.sysds.test.AutomatedTestBase; +//import org.apache.sysds.test.TestConfiguration; +//import org.apache.sysds.test.TestUtils; +//import org.junit.Ignore; +//import org.junit.Test; +// +//import java.io.File; +//import java.util.Arrays; +// +//import static org.junit.Assert.fail; +// +//@net.jcip.annotations.NotThreadSafe +//public class FederatedL2SVMPlanningTest extends AutomatedTestBase { +// private static final Log LOG = LogFactory.getLog(FederatedL2SVMPlanningTest.class.getName()); +// +// private final static String TEST_DIR = "functions/privacy/fedplanning/"; +// private final static String TEST_NAME = "FederatedL2SVMPlanningTest"; +// private final static String TEST_NAME_2 = "FederatedL2SVMFunctionPlanningTest"; +// private final static String TEST_CLASS_DIR = TEST_DIR + FederatedL2SVMPlanningTest.class.getSimpleName() + "/"; +// private static File TEST_CONF_FILE; +// +// private final static int blocksize = 1024; +// public final int rows = 1000; +// public final int cols = 100; +// +// @Override +// public void setUp() { +// TestUtils.clearAssertionInformation(); +// addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"})); +// addTestConfiguration(TEST_NAME_2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_2, new String[] {"Z"})); +// } +// +// @Test +// public void runL2SVMFOUTTest(){ +// String[] expectedHeavyHitters = new String[]{ "fed_fedinit", "fed_ba+*", "fed_tak+*", "fed_+*", +// "fed_max", "fed_1-*", "fed_tsmm", "fed_>"}; +// setTestConf("SystemDS-config-fout.xml"); +// loadAndRunTest(expectedHeavyHitters, TEST_NAME); +// } +// +// @Test +// public void runL2SVMHeuristicTest(){ +// String[] expectedHeavyHitters = new String[]{ "fed_fedinit", "fed_ba+*"}; +// setTestConf("SystemDS-config-heuristic.xml"); +// loadAndRunTest(expectedHeavyHitters, TEST_NAME); +// } +// +// @Test +// public void runL2SVMCostBasedTest(){ +// String[] expectedHeavyHitters = new String[]{ "fed_fedinit", "fed_ba+*", "fed_tak+*", "fed_+*", +// "fed_max", "fed_1-*", "fed_tsmm", "fed_>"}; +// setTestConf("SystemDS-config-cost-based.xml"); +// loadAndRunTest(expectedHeavyHitters, TEST_NAME); +// } +// +// @Test +// public void runL2SVMFunctionFOUTTest(){ +// String[] expectedHeavyHitters = new String[]{ "fed_fedinit", "fed_ba+*", "fed_tak+*", "fed_+*", +// "fed_max", "fed_1-*", "fed_tsmm", "fed_>"}; +// setTestConf("SystemDS-config-fout.xml"); +// loadAndRunTest(expectedHeavyHitters, TEST_NAME_2); +// } +// +// @Test +// public void runL2SVMFunctionHeuristicTest(){ +// String[] expectedHeavyHitters = new String[]{ "fed_fedinit", "fed_ba+*"}; +// setTestConf("SystemDS-config-heuristic.xml"); +// loadAndRunTest(expectedHeavyHitters, TEST_NAME_2); +// } +// +// @Test +// public void runL2SVMFunctionCostBasedTest(){ +// String[] expectedHeavyHitters = new String[]{ "fed_fedinit", "fed_ba+*", "fed_tak+*", "fed_+*", +// "fed_max", "fed_1-*", "fed_tsmm", "fed_>"}; +// setTestConf("SystemDS-config-cost-based.xml"); +// loadAndRunTest(expectedHeavyHitters, TEST_NAME_2); +// } +// +// private void setTestConf(String test_conf){ +// TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, test_conf); +// } +// +// private void writeInputMatrices(){ +// writeStandardRowFedMatrix("X1", 65, null); +// writeStandardRowFedMatrix("X2", 75, null); +// writeBinaryVector("Y", 44, null); +// } +// +// private void writeBinaryVector(String matrixName, long seed, PrivacyConstraint privacyConstraint){ +// double[][] matrix = getRandomMatrix(rows, 1, -1, 1, 1, seed); +// for(int i = 0; i < rows; i++) +// matrix[i][0] = (matrix[i][0] > 0) ? 1 : -1; +// MatrixCharacteristics mc = new MatrixCharacteristics(rows, 1, blocksize, rows); +// writeInputMatrixWithMTD(matrixName, matrix, false, mc, privacyConstraint); +// } +// +// @SuppressWarnings("unused") +// private void writeStandardMatrix(String matrixName, long seed, PrivacyConstraint privacyConstraint){ +// writeStandardMatrix(matrixName, seed, rows, privacyConstraint); +// } +// +// private void writeStandardMatrix(String matrixName, long seed, int numRows, PrivacyConstraint privacyConstraint){ +// double[][] matrix = getRandomMatrix(numRows, cols, 0, 1, 1, seed); +// writeStandardMatrix(matrixName, numRows, privacyConstraint, matrix); +// } +// +// private void writeStandardMatrix(String matrixName, int numRows, PrivacyConstraint privacyConstraint, double[][] matrix){ +// MatrixCharacteristics mc = new MatrixCharacteristics(numRows, cols, blocksize, (long) numRows * cols); +// writeInputMatrixWithMTD(matrixName, matrix, false, mc, privacyConstraint); +// } +// +// private void writeStandardRowFedMatrix(String matrixName, long seed, PrivacyConstraint privacyConstraint){ +// int halfRows = rows/2; +// writeStandardMatrix(matrixName, seed, halfRows, privacyConstraint); +// } +// +// private void loadAndRunTest(String[] expectedHeavyHitters, String testName){ +// +// boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; +// Types.ExecMode platformOld = rtplatform; +// rtplatform = Types.ExecMode.SINGLE_NODE; +// +// Thread t1 = null, t2 = null; +// +// try { +// getAndLoadTestConfiguration(testName); +// String HOME = SCRIPT_DIR + TEST_DIR; +// +// writeInputMatrices(); +// +// int port1 = getRandomAvailablePort(); +// int port2 = getRandomAvailablePort(); +// t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); +// t2 = startLocalFedWorkerThread(port2); +// +// // Run actual dml script with federated matrix +// fullDMLScriptName = HOME + testName + ".dml"; +// programArgs = new String[] { "-stats", "-nvargs", +// "X1=" + TestUtils.federatedAddress(port1, input("X1")), +// "X2=" + TestUtils.federatedAddress(port2, input("X2")), +// "Y=" + input("Y"), "r=" + rows, "c=" + cols, "Z=" + output("Z")}; +// runTest(true, false, null, -1); +// +// // Run reference dml script with normal matrix +// fullDMLScriptName = HOME + testName + "Reference.dml"; +// programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), +// "Y=" + input("Y"), "Z=" + expected("Z")}; +// runTest(true, false, null, -1); +// +// // compare via files +// compareResults(1e-9); +// if (!heavyHittersContainsAllString(expectedHeavyHitters)) +// fail("The following expected heavy hitters are missing: " +// + Arrays.toString(missingHeavyHitters(expectedHeavyHitters))); +// } +// finally { +// TestUtils.shutdownThreads(t1, t2); +// rtplatform = platformOld; +// DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; +// } +// } +// +// /** +// * Override default configuration with custom test configuration to ensure +// * scratch space and local temporary directory locations are also updated. +// */ +// @Override +// protected File getConfigTemplateFile() { +// // Instrumentation in this test's output log to show custom configuration file used for template. +// LOG.info("This test case overrides default configuration with " + TEST_CONF_FILE.getPath()); +// return TEST_CONF_FILE; +// } +// +//} diff --git a/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedMultiplyPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedMultiplyPlanningTest.java new file mode 100644 index 00000000000..f3eee4ee41c --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedMultiplyPlanningTest.java @@ -0,0 +1,334 @@ +package org.apache.sysds.test.functions.fedplanning; +///* +// * Licensed to the Apache Software Foundation (ASF) under one +// * or more contributor license agreements. See the NOTICE file +// * distributed with this work for additional information +// * regarding copyright ownership. The ASF licenses this file +// * to you under the Apache License, Version 2.0 (the +// * "License"); you may not use this file except in compliance +// * with the License. You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, +// * software distributed under the License is distributed on an +// * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// * KIND, either express or implied. See the License for the +// * specific language governing permissions and limitations +// * under the License. +// */ +// +//package org.apache.sysds.test.functions.privacy.fedplanning; +// +//import org.apache.commons.logging.Log; +//import org.apache.commons.logging.LogFactory; +//import org.apache.sysds.runtime.privacy.PrivacyConstraint; +//import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel; +//import org.junit.Test; +//import org.junit.runner.RunWith; +//import org.junit.runners.Parameterized; +//import org.apache.sysds.api.DMLScript; +//import org.apache.sysds.common.Types; +//import org.apache.sysds.runtime.meta.MatrixCharacteristics; +//import org.apache.sysds.test.AutomatedTestBase; +//import org.apache.sysds.test.TestConfiguration; +//import org.apache.sysds.test.TestUtils; +// +//import java.io.File; +//import java.util.Arrays; +//import java.util.Collection; +// +//import static org.junit.Assert.fail; +// +//@RunWith(value = Parameterized.class) +//@net.jcip.annotations.NotThreadSafe +//public class FederatedMultiplyPlanningTest extends AutomatedTestBase { +// private static final Log LOG = LogFactory.getLog(FederatedMultiplyPlanningTest.class.getName()); +// +// private final static String TEST_DIR = "functions/privacy/fedplanning/"; +// private final static String TEST_NAME = "FederatedMultiplyPlanningTest"; +// private final static String TEST_NAME_2 = "FederatedMultiplyPlanningTest2"; +// private final static String TEST_NAME_3 = "FederatedMultiplyPlanningTest3"; +// private final static String TEST_NAME_4 = "FederatedMultiplyPlanningTest4"; +// private final static String TEST_NAME_5 = "FederatedMultiplyPlanningTest5"; +// private final static String TEST_NAME_6 = "FederatedMultiplyPlanningTest6"; +// private final static String TEST_NAME_7 = "FederatedMultiplyPlanningTest7"; +// private final static String TEST_NAME_8 = "FederatedMultiplyPlanningTest8"; +// private final static String TEST_NAME_9 = "FederatedMultiplyPlanningTest9"; +// private final static String TEST_NAME_10 = "FederatedMultiplyPlanningTest10"; +// private final static String TEST_NAME_11 = "FederatedMultiplyPlanningTest11"; +// private final static String TEST_NAME_12 = "FederatedMultiplyPlanningTest12"; +// private final static String TEST_CLASS_DIR = TEST_DIR + FederatedMultiplyPlanningTest.class.getSimpleName() + "/"; +// private static File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, "SystemDS-config-cost-based.xml"); +// +// private final static int blocksize = 1024; +// @Parameterized.Parameter() +// public int rows; +// @Parameterized.Parameter(1) +// public int cols; +// +// @Override +// public void setUp() { +// TestUtils.clearAssertionInformation(); +// addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"})); +// addTestConfiguration(TEST_NAME_2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_2, new String[] {"Z"})); +// addTestConfiguration(TEST_NAME_3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_3, new String[] {"Z.scalar"})); +// addTestConfiguration(TEST_NAME_4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_4, new String[] {"Z"})); +// addTestConfiguration(TEST_NAME_5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_5, new String[] {"Z"})); +// addTestConfiguration(TEST_NAME_6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_6, new String[] {"Z"})); +// addTestConfiguration(TEST_NAME_7, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_7, new String[] {"Z"})); +// addTestConfiguration(TEST_NAME_8, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_8, new String[] {"Z.scalar"})); +// addTestConfiguration(TEST_NAME_9, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_9, new String[] {"Z.scalar"})); +// addTestConfiguration(TEST_NAME_10, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_10, new String[] {"Z"})); +// addTestConfiguration(TEST_NAME_11, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_11, new String[] {"Z"})); +// addTestConfiguration(TEST_NAME_12, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_12, new String[] {"Z"})); +// } +// +// @Parameterized.Parameters +// public static Collection data() { +// // rows have to be even and > 1 +// return Arrays.asList(new Object[][] { +// {100, 10} +// }); +// } +// +// @Test +// public void federatedMultiplyCP() { +// String[] expectedHeavyHitters = new String[]{"fed_*", "fed_fedinit", "fed_r'", "fed_ba+*"}; +// federatedTwoMatricesSingleNodeTest(TEST_NAME, expectedHeavyHitters); +// } +// +// @Test +// public void federatedRowSum(){ +// String[] expectedHeavyHitters = new String[]{"fed_*", "fed_r'", "fed_fedinit", "fed_ba+*", "fed_uark+"}; +// federatedTwoMatricesSingleNodeTest(TEST_NAME_2, expectedHeavyHitters); +// } +// +// @Test +// public void federatedTernarySequence(){ +// String[] expectedHeavyHitters = new String[]{"fed_+*", "fed_1-*", "fed_fedinit", "fed_uak+"}; +// federatedTwoMatricesSingleNodeTest(TEST_NAME_3, expectedHeavyHitters); +// } +// +// @Test +// public void federatedAggregateBinarySequence(){ +// cols = rows; +// String[] expectedHeavyHitters = new String[]{"fed_ba+*", "fed_*", "fed_fedinit"}; +// federatedTwoMatricesSingleNodeTest(TEST_NAME_4, expectedHeavyHitters); +// } +// +// @Test +// public void federatedAggregateBinaryColFedSequence(){ +// cols = rows; +// //TODO: When alignment checks have been added to getFederatedOut in AFederatedPlanner, +// // the following expectedHeavyHitters can be added. Until then, fed_* will not be generated. +// //String[] expectedHeavyHitters = new String[]{"fed_ba+*","fed_*","fed_fedinit"}; +// String[] expectedHeavyHitters = new String[]{"fed_ba+*","fed_fedinit"}; +// federatedTwoMatricesSingleNodeTest(TEST_NAME_5, expectedHeavyHitters); +// } +// +// @Test +// public void federatedAggregateBinarySequence2(){ +// String[] expectedHeavyHitters = new String[]{"fed_ba+*","fed_fedinit"}; +// federatedTwoMatricesSingleNodeTest(TEST_NAME_6, expectedHeavyHitters); +// } +// +// @Test +// public void federatedMultiplyDoubleHop() { +// String[] expectedHeavyHitters = new String[]{"fed_*", "fed_fedinit", "fed_r'", "fed_ba+*"}; +// federatedTwoMatricesSingleNodeTest(TEST_NAME_7, expectedHeavyHitters); +// } +// +// @Test +// public void federatedMultiplyDoubleHop2() { +// String[] expectedHeavyHitters = new String[]{"fed_fedinit", "fed_ba+*"}; +// federatedTwoMatricesSingleNodeTest(TEST_NAME_8, expectedHeavyHitters); +// } +// +// @Test +// public void federatedMultiplyPlanningTest9(){ +// String[] expectedHeavyHitters = new String[]{"fed_+*", "fed_1-*", "fed_fedinit", "fed_tak+*", "fed_max"}; +// federatedTwoMatricesSingleNodeTest(TEST_NAME_9, expectedHeavyHitters); +// } +// +// @Test +// public void federatedMultiplyPlanningTest10(){ +// String[] expectedHeavyHitters = new String[]{"fed_fedinit", "fed_^2"}; +// TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, "SystemDS-config-fout.xml"); +// federatedTwoMatricesSingleNodeTest(TEST_NAME_10, expectedHeavyHitters); +// } +// +// @Test +// public void federatedMultiplyPlanningTest11(){ +// String[] expectedHeavyHitters = new String[]{"fed_fedinit"}; +// federatedTwoMatricesSingleNodeTest(TEST_NAME_11, expectedHeavyHitters); +// } +// +// @Test +// public void federatedMultiplyPlanningTest12(){ +// String[] expectedHeavyHitters = new String[]{"fed_fedinit"}; +// rows = 30; +// cols = 30; +// federatedTwoMatricesSingleNodeTest(TEST_NAME_12, expectedHeavyHitters); +// } +// +// private void writeStandardMatrix(String matrixName, long seed){ +// writeStandardMatrix(matrixName, seed, new PrivacyConstraint(PrivacyConstraint.PrivacyLevel.PrivateAggregation)); +// } +// +// private void writeStandardMatrix(String matrixName, long seed, PrivacyConstraint privacyConstraint){ +// int halfRows = rows/2; +// double[][] matrix = getRandomMatrix(halfRows, cols, 0, 1, 1, seed); +// MatrixCharacteristics mc = new MatrixCharacteristics(halfRows, cols, blocksize, (long) halfRows * cols); +// writeInputMatrixWithMTD(matrixName, matrix, false, mc, privacyConstraint); +// } +// +// private void writeColStandardMatrix(String matrixName, long seed){ +// writeColStandardMatrix(matrixName, seed, new PrivacyConstraint(PrivacyLevel.PrivateAggregation)); +// } +// +// private void writeColStandardMatrix(String matrixName, long seed, PrivacyConstraint privacyConstraint){ +// int halfCols = cols/2; +// double[][] matrix = getRandomMatrix(rows, halfCols, 0, 1, 1, seed); +// MatrixCharacteristics mc = new MatrixCharacteristics(rows, halfCols, blocksize, (long) halfCols *rows); +// writeInputMatrixWithMTD(matrixName, matrix, false, mc, privacyConstraint); +// } +// +// private void writeRowFederatedVector(String matrixName, long seed){ +// writeRowFederatedVector(matrixName, seed, new PrivacyConstraint(PrivacyLevel.PrivateAggregation)); +// } +// +// private void writeRowFederatedVector(String matrixName, long seed, PrivacyConstraint privacyConstraint){ +// int halfCols = cols / 2; +// double[][] matrix = getRandomMatrix(halfCols, 1, 0, 1, 1, seed); +// MatrixCharacteristics mc = new MatrixCharacteristics(halfCols, 1, blocksize, (long) halfCols *rows); +// writeInputMatrixWithMTD(matrixName, matrix, false, mc, privacyConstraint); +// } +// +// private void writeInputMatrices(String testName){ +// if ( testName.equals(TEST_NAME_5) ){ +// writeColStandardMatrix("X1", 42); +// writeColStandardMatrix("X2", 1340); +// writeColStandardMatrix("Y1", 44, null); +// writeColStandardMatrix("Y2", 21, null); +// } +// else if ( testName.equals(TEST_NAME_6) ){ +// writeColStandardMatrix("X1", 42); +// writeColStandardMatrix("X2", 1340); +// writeRowFederatedVector("Y1", 44); +// writeRowFederatedVector("Y2", 21); +// } +// else if ( testName.equals(TEST_NAME_8) ){ +// writeColStandardMatrix("X1", 42, null); +// writeColStandardMatrix("X2", 1340, null); +// writeColStandardMatrix("Y1", 44, null); +// writeColStandardMatrix("Y2", 21, null); +// writeColStandardMatrix("W1", 76, null); +// writeColStandardMatrix("W2", 11, null); +// } +// else if ( testName.equals(TEST_NAME_10) || testName.equals(TEST_NAME_12) ){ +// writeStandardMatrix("X1", 42, null); +// writeStandardMatrix("X2", 1340, null); +// } +// else { +// writeStandardMatrix("X1", 42); +// writeStandardMatrix("X2", 1340); +// if ( testName.equals(TEST_NAME_4) ){ +// writeStandardMatrix("Y1", 44, null); +// writeStandardMatrix("Y2", 21, null); +// } +// else { +// writeStandardMatrix("Y1", 44); +// writeStandardMatrix("Y2", 21); +// } +// } +// } +// +// private void federatedTwoMatricesSingleNodeTest(String testName, String[] expectedHeavyHitters){ +// federatedTwoMatricesTest(Types.ExecMode.SINGLE_NODE, testName, expectedHeavyHitters); +// } +// +// private void federatedTwoMatricesTest(Types.ExecMode execMode, String testName, String[] expectedHeavyHitters) { +// boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; +// Types.ExecMode platformOld = rtplatform; +// rtplatform = execMode; +// if(rtplatform == Types.ExecMode.SPARK) { +// DMLScript.USE_LOCAL_SPARK_CONFIG = true; +// } +// Thread t1 = null, t2 = null; +// +// try{ +// getAndLoadTestConfiguration(testName); +// String HOME = SCRIPT_DIR + TEST_DIR; +// +// writeInputMatrices(testName); +// +// int port1 = getRandomAvailablePort(); +// int port2 = getRandomAvailablePort(); +// t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S); +// t2 = startLocalFedWorkerThread(port2); +// +// // Run actual dml script with federated matrix +// fullDMLScriptName = HOME + testName + ".dml"; +// programArgs = new String[] {"-stats", "-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")), +// "X2=" + TestUtils.federatedAddress(port2, input("X2")), +// "Y1=" + TestUtils.federatedAddress(port1, input("Y1")), +// "Y2=" + TestUtils.federatedAddress(port2, input("Y2")), "r=" + rows, "c=" + cols, "Z=" + output("Z")}; +// rewriteRealProgramArgs(testName, port1, port2); +// runTest(true, false, null, -1); +// +// // Run reference dml script with normal matrix +// fullDMLScriptName = HOME + testName + "Reference.dml"; +// programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"), +// "Y2=" + input("Y2"), "Z=" + expected("Z")}; +// rewriteReferenceProgramArgs(testName); +// runTest(true, false, null, -1); +// +// // compare via files +// compareResults(1e-9); +// if (!heavyHittersContainsAllString(expectedHeavyHitters)) +// fail("The following expected heavy hitters are missing: " +// + Arrays.toString(missingHeavyHitters(expectedHeavyHitters))); +// } finally { +// TestUtils.shutdownThreads(t1, t2); +// rtplatform = platformOld; +// DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; +// } +// } +// +// private void rewriteRealProgramArgs(String testName, int port1, int port2){ +// if ( testName.equals(TEST_NAME_4) || testName.equals(TEST_NAME_5) ){ +// programArgs = new String[] {"-stats","-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")), +// "X2=" + TestUtils.federatedAddress(port2, input("X2")), +// "Y1=" + input("Y1"), +// "Y2=" + input("Y2"), "r=" + rows, "c=" + cols, "Z=" + output("Z")}; +// } else if ( testName.equals(TEST_NAME_8) ){ +// programArgs = new String[] {"-stats","-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")), +// "X2=" + TestUtils.federatedAddress(port2, input("X2")), +// "Y1=" + TestUtils.federatedAddress(port1, input("Y1")), +// "Y2=" + TestUtils.federatedAddress(port2, input("Y2")), +// "W1=" + input("W1"), +// "W2=" + input("W2"), +// "r=" + rows, "c=" + cols, "Z=" + output("Z")}; +// } +// } +// +// private void rewriteReferenceProgramArgs(String testName){ +// if ( testName.equals(TEST_NAME_8) ){ +// programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"), +// "Y2=" + input("Y2"), "W1=" + input("W1"), "W2=" + input("W2"), "Z=" + expected("Z")}; +// } +// } +// +// /** +// * Override default configuration with custom test configuration to ensure +// * scratch space and local temporary directory locations are also updated. +// */ +// @Override +// protected File getConfigTemplateFile() { +// // Instrumentation in this test's output log to show custom configuration file used for template. +// LOG.info("This test case overrides default configuration with " + TEST_CONF_FILE.getPath()); +// return TEST_CONF_FILE; +// } +//} +//