From 29b3c612deb2312130987cc906d6790e7c7187ca Mon Sep 17 00:00:00 2001 From: min-guk Date: Sun, 20 Oct 2024 18:15:36 +0200 Subject: [PATCH] [SYSTEMDS-3729] Add missing federated roll reorg operations Closes #2126. --- .../federated/FederationMap.java | 31 ++- .../instructions/FEDInstructionParser.java | 1 + .../instructions/cp/ReorgCPInstruction.java | 2 +- .../instructions/fed/ReorgFEDInstruction.java | 145 +++++++++++++- .../instructions/fed/UnaryFEDInstruction.java | 6 +- .../spark/ReorgSPInstruction.java | 2 +- .../primitives/part2/FederatedRollTest.java | 187 ++++++++++++++++++ .../functions/federated/FederatedRollTest.dml | 32 +++ .../federated/FederatedRollTestReference.dml | 26 +++ 9 files changed, 422 insertions(+), 10 deletions(-) create mode 100644 src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java create mode 100644 src/test/scripts/functions/federated/FederatedRollTest.dml create mode 100644 src/test/scripts/functions/federated/FederatedRollTestReference.dml diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java index 985fdb056e5..91e6c156c4b 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java @@ -406,14 +406,40 @@ public Future[] executeMultipleSlices(long tid, boolean wait, return ret.toArray(new Future[0]); } + @SuppressWarnings("unchecked") + public Future[] executeRoll(long tid, boolean wait, + FederatedRequest frEnd, FederatedRequest frStart, long rlen) + { + // executes step1[] - step 2 - ... step4 (only first step federated-data-specific) + setThreadID(tid, new FederatedRequest[]{frStart, frEnd}); + List> ret = new ArrayList<>(); + + for(Pair e : _fedMap) { + if (e.getKey().getEndDims()[0] == rlen) { + ret.add(e.getValue().executeFederatedOperation(frEnd)); + } else if (e.getKey().getBeginDims()[0] == 0){ + ret.add(e.getValue().executeFederatedOperation(frStart)); + } + } + + // prepare results (future federated responses), with optional wait to ensure the + // order of requests without data dependencies (e.g., cleanup RPCs) + if(wait) + FederationUtils.waitFor(ret); + return (Future[])ret.toArray(new Future[0]); + } + public List>> requestFederatedData() { if(!isInitialized()) throw new DMLRuntimeException("Federated matrix read only supported on initialized FederatedData"); List>> readResponses = new ArrayList<>(); - FederatedRequest request = new FederatedRequest(RequestType.GET_VAR, _ID); - for(Pair e : _fedMap) + + for(Pair e : _fedMap){ + FederatedRequest request = new FederatedRequest(RequestType.GET_VAR, e.getValue().getVarID()); readResponses.add(Pair.of(e.getKey(), e.getValue().executeFederatedOperation(request))); + } + return readResponses; } @@ -692,6 +718,7 @@ public void reverseFedMap() { } } + private static class MappingTask implements Callable { private final FederatedRange _range; private final FederatedData _data; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java index f61e86e800b..820d07031d6 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java @@ -86,6 +86,7 @@ public class FEDInstructionParser extends InstructionParser String2FEDInstructionType.put( "r'" , FEDType.Reorg ); String2FEDInstructionType.put( "rdiag" , FEDType.Reorg ); String2FEDInstructionType.put( "rev" , FEDType.Reorg ); + String2FEDInstructionType.put( "roll" , FEDType.Reorg ); //String2FEDInstructionType.put( "rshape" , FEDType.Reorg ); Not supported by ReorgFEDInstruction parser! //String2FEDInstructionType.put( "rsort" , FEDType.Reorg ); Not supported by ReorgFEDInstruction parser! diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java index e7b3000d52e..ab105a95855 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java @@ -86,7 +86,7 @@ private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand c * @param istr ? */ private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand shift, String opcode, String istr) { - super(CPType.Reorg, op, in, out, opcode, istr); + super(CPType.Reorg, op, in, shift, out, opcode, istr); _col = null; _desc = null; _ixret = null; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java index c10ca272593..2c8748f7835 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java @@ -36,6 +36,7 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.federated.FederatedRange; +import org.apache.sysds.runtime.controlprogram.federated.FederatedData; import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF; @@ -43,6 +44,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederationUtils; import org.apache.sysds.runtime.functionobjects.DiagIndex; import org.apache.sysds.runtime.functionobjects.RevIndex; +import org.apache.sysds.runtime.functionobjects.RollIndex; import org.apache.sysds.runtime.functionobjects.SwapIndex; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.instructions.cp.CPOperand; @@ -57,6 +59,8 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics; public class ReorgFEDInstruction extends UnaryFEDInstruction { + // roll-specific attributes + private CPOperand _shift = null; public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, String opcode, String istr, FederatedOutput fedOut) { super(FEDType.Reorg, op, in1, out, opcode, istr, fedOut); @@ -66,14 +70,29 @@ public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, String opc super(FEDType.Reorg, op, in1, out, opcode, istr); } + private ReorgFEDInstruction(Operator op, CPOperand in, CPOperand shift, CPOperand out, String opcode, String istr, FederatedOutput fedOut) { + super(FEDType.Reorg, op, in, shift, out, opcode, istr, fedOut); + _shift = shift; + } + public static ReorgFEDInstruction parseInstruction(ReorgCPInstruction rinst) { - return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.output, rinst.getOpcode(), - rinst.getInstructionString(), FederatedOutput.NONE); + if (rinst.input2 != null) { + return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.input2, rinst.output, rinst.getOpcode(), + rinst.getInstructionString(), FederatedOutput.NONE); + } else{ + return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.output, rinst.getOpcode(), + rinst.getInstructionString(), FederatedOutput.NONE); + } } public static ReorgFEDInstruction parseInstruction(ReorgSPInstruction rinst) { - return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.output, rinst.getOpcode(), - rinst.getInstructionString(), FederatedOutput.NONE); + if (rinst.input2 != null) { + return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.input2, rinst.output, rinst.getOpcode(), + rinst.getInstructionString(), FederatedOutput.NONE); + } else{ + return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.output, rinst.getOpcode(), + rinst.getInstructionString(), FederatedOutput.NONE); + } } public static ReorgFEDInstruction parseInstruction(String str) { @@ -105,6 +124,15 @@ else if(opcode.equalsIgnoreCase("rev")) { return new ReorgFEDInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str, fedOut); } + else if (opcode.equalsIgnoreCase("roll")) { + InstructionUtils.checkNumFields(str, 3); + in.split(parts[1]); + out.split(parts[3]); + CPOperand shift = new CPOperand(parts[2]); + fedOut = parseFedOutFlag(str, 3); + return new ReorgFEDInstruction(new ReorgOperator(new RollIndex(0)), + in, out, shift, opcode, str, fedOut); + } else { throw new DMLRuntimeException("ReorgFEDInstruction: unsupported opcode: " + opcode); } @@ -167,6 +195,36 @@ else if(instOpcode.equalsIgnoreCase("rev")) { .setBlocksize(mo1.getBlocksize()).setNonZeros(nnz); out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID())); + optionalForceLocal(out); + } else if (instOpcode.equalsIgnoreCase("roll")) { + long rlen = mo1.getNumRows(); + long shift = ec.getScalarInput(_shift).getLongValue(); + shift %= (rlen != 0 ? rlen : 1); // roll matrix with axis=none + + long inID = mo1.getFedMapping().getID(); + long outEndID = FederationUtils.getNextFedDataID(); + long outStartID = FederationUtils.getNextFedDataID(); + + List> inMap = mo1.getFedMapping().getMap(); + Pair rollResult = rollFedMap( + inMap, inID, outEndID, outStartID, shift, rlen, mo1.getFedMapping().getType()); + long length = rollResult.getValue(); + FederationMap outFedMap = rollResult.getKey(); + + FederatedRequest frEnd = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, outEndID, + new ReorgFEDInstruction.SliceMatrix(inID, outEndID, length, true)); + FederatedRequest frStart = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, outStartID, + new ReorgFEDInstruction.SliceMatrix(inID, outStartID, length, false)); + Future[] ffr = outFedMap.executeRoll(getTID(), true, frEnd, frStart, rlen); + + //derive output federated mapping + MatrixObject out = ec.getMatrixObject(output); + long nnz = (mo1.getNnz() != -1) ? mo1.getNnz() : FederationUtils.sumNonZeros(ffr); + out.getDataCharacteristics() + .setDimension(mo1.getNumRows(), mo1.getNumColumns()) + .setBlocksize(mo1.getBlocksize()) + .setNonZeros(nnz); + out.setFedMapping(outFedMap); optionalForceLocal(out); } else if (instOpcode.equals("rdiag")) { @@ -189,6 +247,40 @@ else if (instOpcode.equals("rdiag")) { } } + + public Pair rollFedMap(List> oldMap, long inID, + long outEndID, long outStartID, long shift, long rlen, FType type) { + List> map = new ArrayList<>(); + long length = 0; + + for(Map.Entry e : oldMap) { + if(e.getKey().getSize() == 0) continue; + FederatedRange fedRange = new FederatedRange(e.getKey()); + long beginRow = fedRange.getBeginDims()[0] + shift; + long endRow = fedRange.getEndDims()[0] + shift; + + beginRow = beginRow > rlen ? beginRow - rlen : beginRow; + endRow = endRow > rlen ? endRow - rlen : endRow; + + if (beginRow < endRow) { + fedRange.setBeginDim(0, beginRow); + fedRange.setEndDim(0, endRow); + map.add(Pair.of(fedRange, e.getValue().copyWithNewID(inID))); + } else { + length = rlen - beginRow; + fedRange.setBeginDim(0, beginRow); + fedRange.setEndDim(0, rlen); + map.add(Pair.of(fedRange, e.getValue().copyWithNewID(outEndID))); + + FederatedRange startRange = new FederatedRange(fedRange); + startRange.setBeginDim(0, 0); + startRange.setEndDim(0, endRow); + map.add(Pair.of(startRange, e.getValue().copyWithNewID(outStartID))); + } + } + return Pair.of(new FederationMap(outEndID, map, type), length); + } + /** * Update the federated ranges of result and return the updated federation map. * @param result RdiagResult for which the fedmap is updated @@ -307,6 +399,51 @@ private RdiagResult rdiagM2V (MatrixObject mo1, ReorgOperator r_op) { return new RdiagResult(diagFedMap, dcs); } + public static class SliceMatrix extends FederatedUDF { + private static final long serialVersionUID = -3466926635958851402L; + private final long _outputID; + private final int _sliceRow; + private final boolean _isRight; + + private SliceMatrix(long input, long outputID, long sliceRow, boolean isRight) { + super(new long[] {input}); + _outputID = outputID; + _sliceRow = (int) sliceRow; + _isRight = isRight; + } + + @Override + public FederatedResponse execute(ExecutionContext ec, Data... data) { + MatrixBlock oriBlock = ((MatrixObject) data[0]).acquireReadAndRelease(); + MatrixBlock resBlock; + + if (_sliceRow != 0){ + if (_isRight){ + resBlock = oriBlock.slice(0, _sliceRow-1, 0, + oriBlock.getNumColumns()-1, new MatrixBlock()); + } else{ + resBlock = oriBlock.slice(_sliceRow, oriBlock.getNumRows()-1, + 0, oriBlock.getNumColumns()-1, new MatrixBlock()); + } + } else{ + resBlock = oriBlock; + } + ec.setMatrixOutput(String.valueOf(_outputID), resBlock); + return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, resBlock); + } + + @Override + public List getOutputIds() { + return new ArrayList<>(Arrays.asList(_outputID)); + } + + @Override + public Pair getLineageItem(ExecutionContext ec) { + return Pair.of(String.valueOf(_outputID), + new LineageItem()); + } + } + public static class Rdiag extends FederatedUDF { private static final long serialVersionUID = -3466926635958851402L; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java index f025983e741..2311a1afe26 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java @@ -88,7 +88,8 @@ public static UnaryFEDInstruction parseInstruction(UnaryCPInstruction inst, Exec } } else if(inst instanceof ReorgCPInstruction && - (inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag") || inst.getOpcode().equals("rev"))) { + (inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag") + || inst.getOpcode().equals("rev") || inst.getOpcode().equals("roll"))) { ReorgCPInstruction rinst = (ReorgCPInstruction) inst; CacheableData mo = ec.getCacheableData(rinst.input1); @@ -157,7 +158,8 @@ else if(inst instanceof AggregateUnarySPInstruction) { return AggregateUnaryFEDInstruction.parseInstruction(auinstruction); } else if(inst instanceof ReorgSPInstruction && - (inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag") || inst.getOpcode().equals("rev"))) { + (inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag") + || inst.getOpcode().equals("rev") || inst.getOpcode().equals("roll"))) { ReorgSPInstruction rinst = (ReorgSPInstruction) inst; CacheableData mo = ec.getCacheableData(rinst.input1); if((mo instanceof MatrixObject || mo instanceof FrameObject) && mo.isFederated() && diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java index b096405959b..1a4f8fef0da 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java @@ -85,7 +85,7 @@ private ReorgSPInstruction(Operator op, CPOperand in, CPOperand col, CPOperand d } private ReorgSPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand shift, String opcode, String istr) { - this(op, in, out, opcode, istr); + super(SPType.Reorg, op, in, shift, null, out, opcode, istr); _shift = shift; } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java new file mode 100644 index 00000000000..f242710338d --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.federated.primitives.part2; + +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.Collection; + +@RunWith(value = Parameterized.class) +@net.jcip.annotations.NotThreadSafe +public class FederatedRollTest extends AutomatedTestBase { + // private static final Log LOG = LogFactory.getLog(FederatedRightIndexTest.class.getName()); + + private final static String TEST_NAME = "FederatedRollTest"; + + private final static String TEST_DIR = "functions/federated/"; + private static final String TEST_CLASS_DIR = TEST_DIR + FederatedRollTest.class.getSimpleName() + "/"; + + private final static int blocksize = 1024; + @Parameterized.Parameter() + public int rows; + @Parameterized.Parameter(1) + public int cols; + + @Parameterized.Parameter(2) + public boolean rowPartitioned; + + @Parameterized.Parameters + public static Collection data() { + return Arrays.asList(new Object[][]{{100, 12, true}, {100, 12, false}}); + } + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"S"})); + } + + @Test + public void testRollCP() { + runRollTest(ExecMode.SINGLE_NODE); + } + + @Test + @Ignore + public void testRollSP() { + runRollTest(ExecMode.SPARK); + } + + @Test + public void federatedCompilationRollCP() { + runRollTest(ExecMode.SINGLE_NODE, true); + } + + @Test + @Ignore + public void federatedCompilationRollSP() { + runRollTest(ExecMode.SPARK, true); + } + + private void runRollTest(ExecMode execMode) { + runRollTest(execMode, false); + } + + private void runRollTest(ExecMode execMode, boolean activateFedCompilation) { + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + ExecMode platformOld = rtplatform; + + if (rtplatform == ExecMode.SPARK) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + + getAndLoadTestConfiguration(TEST_NAME); + String HOME = SCRIPT_DIR + TEST_DIR; + + // write input matrices + int r = rows; + int c = cols / 4; + if (rowPartitioned) { + r = rows / 4; + c = cols; + } + + double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3); + double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7); + double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8); + double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9); + + for (int k : new int[]{1, 2, 3}) { + Arrays.fill(X3[k], 0); + } + + MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c); + writeInputMatrixWithMTD("X1", X1, false, mc); + writeInputMatrixWithMTD("X2", X2, false, mc); + writeInputMatrixWithMTD("X3", X3, false, mc); + writeInputMatrixWithMTD("X4", X4, false, mc); + + // empty script name because we don't execute any script, just start the worker + fullDMLScriptName = ""; + int port1 = getRandomAvailablePort(); + int port2 = getRandomAvailablePort(); + int port3 = getRandomAvailablePort(); + int port4 = getRandomAvailablePort(); + Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S); + Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S); + Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S); + Process t4 = startLocalFedWorker(port4); + + + try { + if (!isAlive(t1, t2, t3, t4)) + throw new RuntimeException("Failed starting federated worker"); + rtplatform = execMode; + if (rtplatform == ExecMode.SPARK) { + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + } + TestConfiguration config = availableTestConfigurations.get(TEST_NAME); + loadTestConfiguration(config); + + // Run reference dml script with normal matrix + fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; + programArgs = new String[]{"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"), + Boolean.toString(rowPartitioned).toUpperCase(), expected("S")}; + + runTest(null); + + OptimizerUtils.FEDERATED_COMPILATION = activateFedCompilation; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[]{"-stats", "100", "-nvargs", + "in_X1=" + TestUtils.federatedAddress(port1, input("X1")), + "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), + "in_X3=" + TestUtils.federatedAddress(port3, input("X3")), + "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols, + "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")}; + + runTest(null); + + // compare via files + compareResults(0.01, "Stat-DML1", "Stat-DML2"); + + Assert.assertTrue(heavyHittersContainsString("fed_roll")); + + // check that federated input files are still existing + Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1"))); + Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2"))); + Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3"))); + Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4"))); + + } finally { + TestUtils.shutdownThreads(t1, t2, t3, t4); + + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + OptimizerUtils.FEDERATED_COMPILATION = false; + } + } +} diff --git a/src/test/scripts/functions/federated/FederatedRollTest.dml b/src/test/scripts/functions/federated/FederatedRollTest.dml new file mode 100644 index 00000000000..cb464256ed8 --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedRollTest.dml @@ -0,0 +1,32 @@ +#------------------------------------------------------------- + # + # Licensed to the Apache Software Foundation (ASF) under one + # or more contributor license agreements. See the NOTICE file + # distributed with this work for additional information + # regarding copyright ownership. The ASF licenses this file + # to you under the Apache License, Version 2.0 (the + # "License"); you may not use this file except in compliance + # with the License. You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, + # software distributed under the License is distributed on an + # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + # KIND, either express or implied. See the License for the + # specific language governing permissions and limitations + # under the License. + # + #------------------------------------------------------------- +if ($rP) { + A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4), + ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols), + list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols))); +} else { + A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4), + ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2), + list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols))); +} + +s = roll(A, 1); +write(s, $out_S); \ No newline at end of file diff --git a/src/test/scripts/functions/federated/FederatedRollTestReference.dml b/src/test/scripts/functions/federated/FederatedRollTestReference.dml new file mode 100644 index 00000000000..694bd5f1d4a --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedRollTestReference.dml @@ -0,0 +1,26 @@ +#------------------------------------------------------------- + # + # Licensed to the Apache Software Foundation (ASF) under one + # or more contributor license agreements. See the NOTICE file + # distributed with this work for additional information + # regarding copyright ownership. The ASF licenses this file + # to you under the Apache License, Version 2.0 (the + # "License"); you may not use this file except in compliance + # with the License. You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, + # software distributed under the License is distributed on an + # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + # KIND, either express or implied. See the License for the + # specific language governing permissions and limitations + # under the License. + # + #------------------------------------------------------------- + + if($5) { A = rbind(read($1), read($2), read($3), read($4)); } + else { A = cbind(read($1), read($2), read($3), read($4)); } + + s = roll(A, 1); + write(s, $6); \ No newline at end of file