Skip to content

Commit

Permalink
[SYSTEMDS-3729] Add missing federated roll reorg operations
Browse files Browse the repository at this point in the history
Closes #2126.
  • Loading branch information
min-guk authored and mboehm7 committed Oct 20, 2024
1 parent 80332e0 commit 29b3c61
Show file tree
Hide file tree
Showing 9 changed files with 422 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -406,14 +406,40 @@ public Future<FederatedResponse>[] executeMultipleSlices(long tid, boolean wait,
return ret.toArray(new Future[0]);
}

@SuppressWarnings("unchecked")
public Future<FederatedResponse>[] executeRoll(long tid, boolean wait,
FederatedRequest frEnd, FederatedRequest frStart, long rlen)
{
// executes step1[] - step 2 - ... step4 (only first step federated-data-specific)
setThreadID(tid, new FederatedRequest[]{frStart, frEnd});
List<Future<FederatedResponse>> ret = new ArrayList<>();

for(Pair<FederatedRange, FederatedData> e : _fedMap) {
if (e.getKey().getEndDims()[0] == rlen) {
ret.add(e.getValue().executeFederatedOperation(frEnd));
} else if (e.getKey().getBeginDims()[0] == 0){
ret.add(e.getValue().executeFederatedOperation(frStart));
}
}

// prepare results (future federated responses), with optional wait to ensure the
// order of requests without data dependencies (e.g., cleanup RPCs)
if(wait)
FederationUtils.waitFor(ret);
return (Future<FederatedResponse>[])ret.toArray(new Future[0]);
}

public List<Pair<FederatedRange, Future<FederatedResponse>>> requestFederatedData() {
if(!isInitialized())
throw new DMLRuntimeException("Federated matrix read only supported on initialized FederatedData");

List<Pair<FederatedRange, Future<FederatedResponse>>> readResponses = new ArrayList<>();
FederatedRequest request = new FederatedRequest(RequestType.GET_VAR, _ID);
for(Pair<FederatedRange, FederatedData> e : _fedMap)

for(Pair<FederatedRange, FederatedData> e : _fedMap){
FederatedRequest request = new FederatedRequest(RequestType.GET_VAR, e.getValue().getVarID());
readResponses.add(Pair.of(e.getKey(), e.getValue().executeFederatedOperation(request)));
}

return readResponses;
}

Expand Down Expand Up @@ -692,6 +718,7 @@ public void reverseFedMap() {
}
}


private static class MappingTask implements Callable<Void> {
private final FederatedRange _range;
private final FederatedData _data;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ public class FEDInstructionParser extends InstructionParser
String2FEDInstructionType.put( "r'" , FEDType.Reorg );
String2FEDInstructionType.put( "rdiag" , FEDType.Reorg );
String2FEDInstructionType.put( "rev" , FEDType.Reorg );
String2FEDInstructionType.put( "roll" , FEDType.Reorg );
//String2FEDInstructionType.put( "rshape" , FEDType.Reorg ); Not supported by ReorgFEDInstruction parser!
//String2FEDInstructionType.put( "rsort" , FEDType.Reorg ); Not supported by ReorgFEDInstruction parser!

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand c
* @param istr ?
*/
private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand shift, String opcode, String istr) {
super(CPType.Reorg, op, in, out, opcode, istr);
super(CPType.Reorg, op, in, shift, out, opcode, istr);
_col = null;
_desc = null;
_ixret = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.functionobjects.DiagIndex;
import org.apache.sysds.runtime.functionobjects.RevIndex;
import org.apache.sysds.runtime.functionobjects.RollIndex;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
Expand All @@ -57,6 +59,8 @@
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

public class ReorgFEDInstruction extends UnaryFEDInstruction {
// roll-specific attributes
private CPOperand _shift = null;

public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, String opcode, String istr, FederatedOutput fedOut) {
super(FEDType.Reorg, op, in1, out, opcode, istr, fedOut);
Expand All @@ -66,14 +70,29 @@ public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, String opc
super(FEDType.Reorg, op, in1, out, opcode, istr);
}

private ReorgFEDInstruction(Operator op, CPOperand in, CPOperand shift, CPOperand out, String opcode, String istr, FederatedOutput fedOut) {
super(FEDType.Reorg, op, in, shift, out, opcode, istr, fedOut);
_shift = shift;
}

public static ReorgFEDInstruction parseInstruction(ReorgCPInstruction rinst) {
return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.output, rinst.getOpcode(),
rinst.getInstructionString(), FederatedOutput.NONE);
if (rinst.input2 != null) {
return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.input2, rinst.output, rinst.getOpcode(),
rinst.getInstructionString(), FederatedOutput.NONE);
} else{
return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.output, rinst.getOpcode(),
rinst.getInstructionString(), FederatedOutput.NONE);
}
}

public static ReorgFEDInstruction parseInstruction(ReorgSPInstruction rinst) {
return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.output, rinst.getOpcode(),
rinst.getInstructionString(), FederatedOutput.NONE);
if (rinst.input2 != null) {
return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.input2, rinst.output, rinst.getOpcode(),
rinst.getInstructionString(), FederatedOutput.NONE);
} else{
return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.output, rinst.getOpcode(),
rinst.getInstructionString(), FederatedOutput.NONE);
}
}

public static ReorgFEDInstruction parseInstruction(String str) {
Expand Down Expand Up @@ -105,6 +124,15 @@ else if(opcode.equalsIgnoreCase("rev")) {
return new ReorgFEDInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str,
fedOut);
}
else if (opcode.equalsIgnoreCase("roll")) {
InstructionUtils.checkNumFields(str, 3);
in.split(parts[1]);
out.split(parts[3]);
CPOperand shift = new CPOperand(parts[2]);
fedOut = parseFedOutFlag(str, 3);
return new ReorgFEDInstruction(new ReorgOperator(new RollIndex(0)),
in, out, shift, opcode, str, fedOut);
}
else {
throw new DMLRuntimeException("ReorgFEDInstruction: unsupported opcode: " + opcode);
}
Expand Down Expand Up @@ -167,6 +195,36 @@ else if(instOpcode.equalsIgnoreCase("rev")) {
.setBlocksize(mo1.getBlocksize()).setNonZeros(nnz);
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));

optionalForceLocal(out);
} else if (instOpcode.equalsIgnoreCase("roll")) {
long rlen = mo1.getNumRows();
long shift = ec.getScalarInput(_shift).getLongValue();
shift %= (rlen != 0 ? rlen : 1); // roll matrix with axis=none

long inID = mo1.getFedMapping().getID();
long outEndID = FederationUtils.getNextFedDataID();
long outStartID = FederationUtils.getNextFedDataID();

List<Pair<FederatedRange, FederatedData>> inMap = mo1.getFedMapping().getMap();
Pair<FederationMap, Long> rollResult = rollFedMap(
inMap, inID, outEndID, outStartID, shift, rlen, mo1.getFedMapping().getType());
long length = rollResult.getValue();
FederationMap outFedMap = rollResult.getKey();

FederatedRequest frEnd = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, outEndID,
new ReorgFEDInstruction.SliceMatrix(inID, outEndID, length, true));
FederatedRequest frStart = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, outStartID,
new ReorgFEDInstruction.SliceMatrix(inID, outStartID, length, false));
Future<FederatedResponse>[] ffr = outFedMap.executeRoll(getTID(), true, frEnd, frStart, rlen);

//derive output federated mapping
MatrixObject out = ec.getMatrixObject(output);
long nnz = (mo1.getNnz() != -1) ? mo1.getNnz() : FederationUtils.sumNonZeros(ffr);
out.getDataCharacteristics()
.setDimension(mo1.getNumRows(), mo1.getNumColumns())
.setBlocksize(mo1.getBlocksize())
.setNonZeros(nnz);
out.setFedMapping(outFedMap);
optionalForceLocal(out);
}
else if (instOpcode.equals("rdiag")) {
Expand All @@ -189,6 +247,40 @@ else if (instOpcode.equals("rdiag")) {
}
}


public Pair<FederationMap, Long> rollFedMap(List<Pair<FederatedRange, FederatedData>> oldMap, long inID,
long outEndID, long outStartID, long shift, long rlen, FType type) {
List<Pair<FederatedRange, FederatedData>> map = new ArrayList<>();
long length = 0;

for(Map.Entry<FederatedRange, FederatedData> e : oldMap) {
if(e.getKey().getSize() == 0) continue;
FederatedRange fedRange = new FederatedRange(e.getKey());
long beginRow = fedRange.getBeginDims()[0] + shift;
long endRow = fedRange.getEndDims()[0] + shift;

beginRow = beginRow > rlen ? beginRow - rlen : beginRow;
endRow = endRow > rlen ? endRow - rlen : endRow;

if (beginRow < endRow) {
fedRange.setBeginDim(0, beginRow);
fedRange.setEndDim(0, endRow);
map.add(Pair.of(fedRange, e.getValue().copyWithNewID(inID)));
} else {
length = rlen - beginRow;
fedRange.setBeginDim(0, beginRow);
fedRange.setEndDim(0, rlen);
map.add(Pair.of(fedRange, e.getValue().copyWithNewID(outEndID)));

FederatedRange startRange = new FederatedRange(fedRange);
startRange.setBeginDim(0, 0);
startRange.setEndDim(0, endRow);
map.add(Pair.of(startRange, e.getValue().copyWithNewID(outStartID)));
}
}
return Pair.of(new FederationMap(outEndID, map, type), length);
}

/**
* Update the federated ranges of result and return the updated federation map.
* @param result RdiagResult for which the fedmap is updated
Expand Down Expand Up @@ -307,6 +399,51 @@ private RdiagResult rdiagM2V (MatrixObject mo1, ReorgOperator r_op) {
return new RdiagResult(diagFedMap, dcs);
}

public static class SliceMatrix extends FederatedUDF {
private static final long serialVersionUID = -3466926635958851402L;
private final long _outputID;
private final int _sliceRow;
private final boolean _isRight;

private SliceMatrix(long input, long outputID, long sliceRow, boolean isRight) {
super(new long[] {input});
_outputID = outputID;
_sliceRow = (int) sliceRow;
_isRight = isRight;
}

@Override
public FederatedResponse execute(ExecutionContext ec, Data... data) {
MatrixBlock oriBlock = ((MatrixObject) data[0]).acquireReadAndRelease();
MatrixBlock resBlock;

if (_sliceRow != 0){
if (_isRight){
resBlock = oriBlock.slice(0, _sliceRow-1, 0,
oriBlock.getNumColumns()-1, new MatrixBlock());
} else{
resBlock = oriBlock.slice(_sliceRow, oriBlock.getNumRows()-1,
0, oriBlock.getNumColumns()-1, new MatrixBlock());
}
} else{
resBlock = oriBlock;
}
ec.setMatrixOutput(String.valueOf(_outputID), resBlock);
return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, resBlock);
}

@Override
public List<Long> getOutputIds() {
return new ArrayList<>(Arrays.asList(_outputID));
}

@Override
public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
return Pair.of(String.valueOf(_outputID),
new LineageItem());
}
}

public static class Rdiag extends FederatedUDF {

private static final long serialVersionUID = -3466926635958851402L;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ public static UnaryFEDInstruction parseInstruction(UnaryCPInstruction inst, Exec
}
}
else if(inst instanceof ReorgCPInstruction &&
(inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag") || inst.getOpcode().equals("rev"))) {
(inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag")
|| inst.getOpcode().equals("rev") || inst.getOpcode().equals("roll"))) {
ReorgCPInstruction rinst = (ReorgCPInstruction) inst;
CacheableData<?> mo = ec.getCacheableData(rinst.input1);

Expand Down Expand Up @@ -157,7 +158,8 @@ else if(inst instanceof AggregateUnarySPInstruction) {
return AggregateUnaryFEDInstruction.parseInstruction(auinstruction);
}
else if(inst instanceof ReorgSPInstruction &&
(inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag") || inst.getOpcode().equals("rev"))) {
(inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag")
|| inst.getOpcode().equals("rev") || inst.getOpcode().equals("roll"))) {
ReorgSPInstruction rinst = (ReorgSPInstruction) inst;
CacheableData<?> mo = ec.getCacheableData(rinst.input1);
if((mo instanceof MatrixObject || mo instanceof FrameObject) && mo.isFederated() &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ private ReorgSPInstruction(Operator op, CPOperand in, CPOperand col, CPOperand d
}

private ReorgSPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand shift, String opcode, String istr) {
this(op, in, out, opcode, istr);
super(SPType.Reorg, op, in, shift, null, out, opcode, istr);
_shift = shift;
}

Expand Down
Loading

0 comments on commit 29b3c61

Please sign in to comment.