From c9298438f1c950b160558f3b77e28b73dfee9f8b Mon Sep 17 00:00:00 2001 From: Matthias Boehm Date: Sun, 17 Nov 2024 11:04:54 +0100 Subject: [PATCH] [MINOR] Cleanup code quality (tab formatting, method annotations) --- .../hops/cost/CostEstimatorStaticRuntime.java | 1 - .../sysds/resource/cost/CPCostUtils.java | 1728 ++++++++--------- .../sysds/resource/cost/SparkCostUtils.java | 1528 +++++++-------- .../runtime/functionobjects/RollIndex.java | 52 +- .../compression/TransformPerf.java | 3 - .../component/resource/CPCostUtilsTest.java | 1098 +++++------ .../InstructionsCostEstimatorTest.java | 331 ++-- ...iltinImageSamplePairingLinearizedTest.java | 132 +- .../builtin/part2/BuiltinTSNETest.java | 1 - 9 files changed, 2434 insertions(+), 2440 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/cost/CostEstimatorStaticRuntime.java b/src/main/java/org/apache/sysds/hops/cost/CostEstimatorStaticRuntime.java index 61c2048836b..cbb0cb3e743 100644 --- a/src/main/java/org/apache/sysds/hops/cost/CostEstimatorStaticRuntime.java +++ b/src/main/java/org/apache/sysds/hops/cost/CostEstimatorStaticRuntime.java @@ -62,7 +62,6 @@ public class CostEstimatorStaticRuntime extends CostEstimator private static final double DEFAULT_MBS_HDFSWRITE_TEXT_SPARSE = 30; @Override - @SuppressWarnings("unused") protected double getCPInstTimeEstimate( Instruction inst, VarStats[] vs, String[] args ) { CPInstruction cpinst = (CPInstruction)inst; diff --git a/src/main/java/org/apache/sysds/resource/cost/CPCostUtils.java b/src/main/java/org/apache/sysds/resource/cost/CPCostUtils.java index 6d46070112f..7d82422050c 100644 --- a/src/main/java/org/apache/sysds/resource/cost/CPCostUtils.java +++ b/src/main/java/org/apache/sysds/resource/cost/CPCostUtils.java @@ -31,887 +31,887 @@ import static org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType; public class CPCostUtils { - private static final long DEFAULT_NFLOP_NOOP = 10; - private static final long DEFAULT_NFLOP_CP = 1; - private static final long DEFAULT_NFLOP_TEXT_IO = 350; - private static final long DEFAULT_INFERRED_DIM = 1000000; + private static final long DEFAULT_NFLOP_NOOP = 10; + private static final long DEFAULT_NFLOP_CP = 1; + private static final long DEFAULT_NFLOP_TEXT_IO = 350; + private static final long DEFAULT_INFERRED_DIM = 1000000; - public static double getVariableInstTime(VariableCPInstruction inst, VarStats input, VarStats output, IOMetrics metrics) { - long nflop; - switch (inst.getOpcode()) { - case "write": - String fmtStr = inst.getInput3().getLiteral().getStringValue(); - Types.FileFormat fmt = Types.FileFormat.safeValueOf(fmtStr); - long xwrite = fmt.isTextFormat() ? 
DEFAULT_NFLOP_TEXT_IO : DEFAULT_NFLOP_CP; - nflop = input.getCellsWithSparsity() * xwrite; - break; - case "cast_as_matrix": - case "cast_as_frame": - nflop = input.getCells(); - break; - case "rmfilevar": case "attachfiletovar": case "setfilename": - throw new RuntimeException("Undefined behaviour for instruction with opcode: " + inst.getOpcode()); - default: - // negligibly low number of FLOP (independent of variables' dimensions) - return 0; - } - // assignOutputMemoryStats() needed only for casts - return getCPUTime(nflop, metrics, output, input); - } + public static double getVariableInstTime(VariableCPInstruction inst, VarStats input, VarStats output, IOMetrics metrics) { + long nflop; + switch (inst.getOpcode()) { + case "write": + String fmtStr = inst.getInput3().getLiteral().getStringValue(); + Types.FileFormat fmt = Types.FileFormat.safeValueOf(fmtStr); + long xwrite = fmt.isTextFormat() ? DEFAULT_NFLOP_TEXT_IO : DEFAULT_NFLOP_CP; + nflop = input.getCellsWithSparsity() * xwrite; + break; + case "cast_as_matrix": + case "cast_as_frame": + nflop = input.getCells(); + break; + case "rmfilevar": case "attachfiletovar": case "setfilename": + throw new RuntimeException("Undefined behaviour for instruction with opcode: " + inst.getOpcode()); + default: + // negligibly low number of FLOP (independent of variables' dimensions) + return 0; + } + // assignOutputMemoryStats() needed only for casts + return getCPUTime(nflop, metrics, output, input); + } - public static double getDataGenCPInstTime(UnaryCPInstruction inst, VarStats output, IOMetrics metrics) { - long nflop; - String opcode = inst.getOpcode(); - if( inst instanceof DataGenCPInstruction) { - if (opcode.equals("rand") || opcode.equals("frame")) { - DataGenCPInstruction rinst = (DataGenCPInstruction) inst; - if( rinst.getMinValue() == 0.0 && rinst.getMaxValue() == 0.0 ) - nflop = 0; // empty matrix - else if( rinst.getSparsity() == 1.0 && rinst.getMinValue() == rinst.getMaxValue() ) // allocate, array fill - nflop = 8 * output.getCells(); - else { // full rand - if (rinst.getSparsity() == 1.0) - nflop = 32 * output.getCells() + 8 * output.getCells(); // DENSE gen (incl allocate) - else if (rinst.getSparsity() < MatrixBlock.SPARSITY_TURN_POINT) - nflop = 3 * output.getCellsWithSparsity() + 24 * output.getCellsWithSparsity(); //SPARSE gen (incl allocate) - else - nflop = 2 * output.getCells() + 8 * output.getCells(); // DENSE gen (incl allocate) - } - } else if (opcode.equals(DataGen.SEQ_OPCODE)) { - nflop = DEFAULT_NFLOP_CP * output.getCells(); - } else { - // DataGen.SAMPLE_OPCODE, DataGen.TIME_OPCODE, - throw new RuntimeException("Undefined behaviour for instruction with opcode: " + inst.getOpcode()); - } - } - else if( inst instanceof StringInitCPInstruction) { - nflop = DEFAULT_NFLOP_CP * output.getCells(); - } else { - throw new IllegalArgumentException("Method has been called with invalid instruction: " + inst); - } - return getCPUTime(nflop, metrics, output); - } + public static double getDataGenCPInstTime(UnaryCPInstruction inst, VarStats output, IOMetrics metrics) { + long nflop; + String opcode = inst.getOpcode(); + if( inst instanceof DataGenCPInstruction) { + if (opcode.equals("rand") || opcode.equals("frame")) { + DataGenCPInstruction rinst = (DataGenCPInstruction) inst; + if( rinst.getMinValue() == 0.0 && rinst.getMaxValue() == 0.0 ) + nflop = 0; // empty matrix + else if( rinst.getSparsity() == 1.0 && rinst.getMinValue() == rinst.getMaxValue() ) // allocate, array fill + nflop = 8 * output.getCells(); + else { // full 
rand
+					if (rinst.getSparsity() == 1.0)
+						nflop = 32 * output.getCells() + 8 * output.getCells(); // DENSE gen (incl allocate)
+					else if (rinst.getSparsity() < MatrixBlock.SPARSITY_TURN_POINT)
+						nflop = 3 * output.getCellsWithSparsity() + 24 * output.getCellsWithSparsity(); //SPARSE gen (incl allocate)
+					else
+						nflop = 2 * output.getCells() + 8 * output.getCells(); // DENSE gen (incl allocate)
+				}
+			} else if (opcode.equals(DataGen.SEQ_OPCODE)) {
+				nflop = DEFAULT_NFLOP_CP * output.getCells();
+			} else {
+				// DataGen.SAMPLE_OPCODE, DataGen.TIME_OPCODE,
+				throw new RuntimeException("Undefined behaviour for instruction with opcode: " + inst.getOpcode());
+			}
+		}
+		else if( inst instanceof StringInitCPInstruction) {
+			nflop = DEFAULT_NFLOP_CP * output.getCells();
+		} else {
+			throw new IllegalArgumentException("Method has been called with invalid instruction: " + inst);
+		}
+		return getCPUTime(nflop, metrics, output);
+	}
-	public static double getUnaryInstTime(UnaryCPInstruction inst, VarStats input, VarStats weights, VarStats output, IOMetrics metrics) {
-		if (inst instanceof UaggOuterChainCPInstruction || inst instanceof DnnCPInstruction) {
-			throw new RuntimeException("Time estimation for CP instruction of class " + inst.getClass().getName() + "not supported yet");
-		}
-		// CPType = Unary/Builtin
-		CPType instructionType = inst.getCPInstructionType();
-		String opcode = inst.getOpcode();
-		boolean includeWeights = false;
-		if (inst instanceof MMTSJCPInstruction) {
-			MMTSJ.MMTSJType type = ((MMTSJCPInstruction) inst).getMMTSJType();
-			opcode += type.isLeft() ? "_left" : "_right";
-		} else if (inst instanceof ReorgCPInstruction && opcode.equals("rsort")) {
-			if (inst.input2 != null) includeWeights = true;
-		} else if (inst instanceof QuantileSortCPInstruction) {
-			if (inst.input2 != null) {
-				opcode += "_wts";
-				includeWeights = true;
-			}
-		} else if (inst instanceof CentralMomentCPInstruction) {
-			CMOperator.AggregateOperationTypes opType = ((CMOperator) inst.getOperator()).getAggOpType();
-			opcode += "_" + opType.name().toLowerCase();
-			if (inst.input2 != null) {
-				includeWeights = true;
-			}
-		}
-		long nflop = getInstNFLOP(instructionType, opcode, output, input);
-		if (includeWeights)
-			return getCPUTime(nflop, metrics, output, input, weights);
-		return getCPUTime(nflop, metrics, output, input);
-	}
+	public static double getUnaryInstTime(UnaryCPInstruction inst, VarStats input, VarStats weights, VarStats output, IOMetrics metrics) {
+		if (inst instanceof UaggOuterChainCPInstruction || inst instanceof DnnCPInstruction) {
+			throw new RuntimeException("Time estimation for CP instruction of class " + inst.getClass().getName() + " not supported yet");
+		}
+		// CPType = Unary/Builtin
+		CPType instructionType = inst.getCPInstructionType();
+		String opcode = inst.getOpcode();
+		boolean includeWeights = false;
+		if (inst instanceof MMTSJCPInstruction) {
+			MMTSJ.MMTSJType type = ((MMTSJCPInstruction) inst).getMMTSJType();
+			opcode += type.isLeft() ? 
"_left" : "_right"; + } else if (inst instanceof ReorgCPInstruction && opcode.equals("rsort")) { + if (inst.input2 != null) includeWeights = true; + } else if (inst instanceof QuantileSortCPInstruction) { + if (inst.input2 != null) { + opcode += "_wts"; + includeWeights = true; + } + } else if (inst instanceof CentralMomentCPInstruction) { + CMOperator.AggregateOperationTypes opType = ((CMOperator) inst.getOperator()).getAggOpType(); + opcode += "_" + opType.name().toLowerCase(); + if (inst.input2 != null) { + includeWeights = true; + } + } + long nflop = getInstNFLOP(instructionType, opcode, output, input); + if (includeWeights) + return getCPUTime(nflop, metrics, output, input, weights); + return getCPUTime(nflop, metrics, output, input); + } - public static double getBinaryInstTime(BinaryCPInstruction inst, VarStats input1, VarStats input2, VarStats weights, VarStats output, IOMetrics metrics) { - // CPType = Binary/Builtin - CPType instructionType = inst.getCPInstructionType(); - String opcode = inst.getOpcode(); + public static double getBinaryInstTime(BinaryCPInstruction inst, VarStats input1, VarStats input2, VarStats weights, VarStats output, IOMetrics metrics) { + // CPType = Binary/Builtin + CPType instructionType = inst.getCPInstructionType(); + String opcode = inst.getOpcode(); - boolean includeWeights = false; - if (inst instanceof CovarianceCPInstruction) { // cov - includeWeights = true; - } else if (inst instanceof QuantilePickCPInstruction) { - PickByCount.OperationTypes opType = ((QuantilePickCPInstruction) inst).getOperationType(); - opcode += "_" + opType.name().toLowerCase(); - } else if (inst instanceof AggregateBinaryCPInstruction) { - AggregateBinaryCPInstruction abinst = (AggregateBinaryCPInstruction) inst; - opcode += abinst.transposeLeft? "_tl": ""; - opcode += abinst.transposeRight? "_tr": ""; - } - long nflop = getInstNFLOP(instructionType, opcode, output, input1, input2); - if (includeWeights) - return getCPUTime(nflop, metrics, output, input1, input2, weights); - return getCPUTime(nflop, metrics, output, input1, input2); - } + boolean includeWeights = false; + if (inst instanceof CovarianceCPInstruction) { // cov + includeWeights = true; + } else if (inst instanceof QuantilePickCPInstruction) { + PickByCount.OperationTypes opType = ((QuantilePickCPInstruction) inst).getOperationType(); + opcode += "_" + opType.name().toLowerCase(); + } else if (inst instanceof AggregateBinaryCPInstruction) { + AggregateBinaryCPInstruction abinst = (AggregateBinaryCPInstruction) inst; + opcode += abinst.transposeLeft? "_tl": ""; + opcode += abinst.transposeRight? 
"_tr": ""; + } + long nflop = getInstNFLOP(instructionType, opcode, output, input1, input2); + if (includeWeights) + return getCPUTime(nflop, metrics, output, input1, input2, weights); + return getCPUTime(nflop, metrics, output, input1, input2); + } - public static double getComputationInstTime(ComputationCPInstruction inst, VarStats input1, VarStats input2, VarStats input3, VarStats input4, VarStats output, IOMetrics metrics) { - if (inst instanceof UnaryCPInstruction || inst instanceof BinaryCPInstruction) { - throw new RuntimeException("Instructions of type UnaryCPInstruction and BinaryCPInstruction are not handled by this method"); - } - CPType instructionType = inst.getCPInstructionType(); - String opcode = inst.getOpcode(); + public static double getComputationInstTime(ComputationCPInstruction inst, VarStats input1, VarStats input2, VarStats input3, VarStats input4, VarStats output, IOMetrics metrics) { + if (inst instanceof UnaryCPInstruction || inst instanceof BinaryCPInstruction) { + throw new RuntimeException("Instructions of type UnaryCPInstruction and BinaryCPInstruction are not handled by this method"); + } + CPType instructionType = inst.getCPInstructionType(); + String opcode = inst.getOpcode(); - // CURRENTLY: 2 is the maximum number of needed input stats objects for NFLOP estimation - long nflop = getInstNFLOP(instructionType, opcode, output, input1, input2); - return getCPUTime(nflop, metrics, output, input1, input2, input3, input4); - } + // CURRENTLY: 2 is the maximum number of needed input stats objects for NFLOP estimation + long nflop = getInstNFLOP(instructionType, opcode, output, input1, input2); + return getCPUTime(nflop, metrics, output, input1, input2, input3, input4); + } - public static double getBuiltinNaryInstTime(BuiltinNaryCPInstruction inst, VarStats[] inputs, VarStats output, IOMetrics metrics) { - CPType instructionType = inst.getCPInstructionType(); - String opcode = inst.getOpcode(); - long nflop; - if (inputs == null) { - nflop = getInstNFLOP(instructionType, opcode, output); - return getCPUTime(nflop, metrics, output); - } - nflop = getInstNFLOP(instructionType, opcode, output, inputs); - return getCPUTime(nflop, metrics, output, inputs); - } + public static double getBuiltinNaryInstTime(BuiltinNaryCPInstruction inst, VarStats[] inputs, VarStats output, IOMetrics metrics) { + CPType instructionType = inst.getCPInstructionType(); + String opcode = inst.getOpcode(); + long nflop; + if (inputs == null) { + nflop = getInstNFLOP(instructionType, opcode, output); + return getCPUTime(nflop, metrics, output); + } + nflop = getInstNFLOP(instructionType, opcode, output, inputs); + return getCPUTime(nflop, metrics, output, inputs); + } - public static double getParameterizedBuiltinInstTime(ParameterizedBuiltinCPInstruction inst, VarStats input, VarStats output, IOMetrics metrics) { - CPType instructionType = inst.getCPInstructionType(); - String opcode = inst.getOpcode(); - if (opcode.equals("rmempty")) { - String margin = inst.getParameterMap().get("margin"); - opcode += "_" + margin; - } else if (opcode.equals("groupedagg")) { - CMOperator.AggregateOperationTypes opType = ((CMOperator) inst.getOperator()).getAggOpType(); - opcode += "_" + opType.name().toLowerCase(); - } - long nflop = getInstNFLOP(instructionType, opcode, output, input); - return getCPUTime(nflop, metrics, output, input); - } + public static double getParameterizedBuiltinInstTime(ParameterizedBuiltinCPInstruction inst, VarStats input, VarStats output, IOMetrics metrics) { + CPType 
instructionType = inst.getCPInstructionType(); + String opcode = inst.getOpcode(); + if (opcode.equals("rmempty")) { + String margin = inst.getParameterMap().get("margin"); + opcode += "_" + margin; + } else if (opcode.equals("groupedagg")) { + CMOperator.AggregateOperationTypes opType = ((CMOperator) inst.getOperator()).getAggOpType(); + opcode += "_" + opType.name().toLowerCase(); + } + long nflop = getInstNFLOP(instructionType, opcode, output, input); + return getCPUTime(nflop, metrics, output, input); + } - public static double getMultiReturnBuiltinInstTime(MultiReturnBuiltinCPInstruction inst, VarStats input, VarStats[] outputs, IOMetrics metrics) { - CPType instructionType = inst.getCPInstructionType(); - String opcode = inst.getOpcode(); - long nflop = getInstNFLOP(instructionType, opcode, outputs[0], input); - double time = getCPUTime(nflop, metrics, outputs[0], input); - for (int i = 1; i < outputs.length; i++) { - time += IOCostUtils.getMemWriteTime(outputs[i], metrics); - } - return time; - } + public static double getMultiReturnBuiltinInstTime(MultiReturnBuiltinCPInstruction inst, VarStats input, VarStats[] outputs, IOMetrics metrics) { + CPType instructionType = inst.getCPInstructionType(); + String opcode = inst.getOpcode(); + long nflop = getInstNFLOP(instructionType, opcode, outputs[0], input); + double time = getCPUTime(nflop, metrics, outputs[0], input); + for (int i = 1; i < outputs.length; i++) { + time += IOCostUtils.getMemWriteTime(outputs[i], metrics); + } + return time; + } - // HELPERS - public static void assignOutputMemoryStats(CPInstruction inst, VarStats output, VarStats...inputs) { - CPType instType = inst.getCPInstructionType(); - String opcode = inst.getOpcode(); + // HELPERS + public static void assignOutputMemoryStats(CPInstruction inst, VarStats output, VarStats...inputs) { + CPType instType = inst.getCPInstructionType(); + String opcode = inst.getOpcode(); - if (inst instanceof MultiReturnBuiltinCPInstruction) { - boolean inferred = false; - for (VarStats current : inputs) { - if (!inferred && current.getCells() < 0) { - inferStats(instType, opcode, output, inputs); - inferred = true; - } - if (current.getCells() < 0) { - throw new RuntimeException("Operation of type MultiReturnBuiltin with opcode '" + opcode + "' has incomplete formula for inferring dimensions"); - } - current.allocatedMemory = OptimizerUtils.estimateSizeExactSparsity(current.characteristics); - } - return; - } else if (output.getCells() < 0) { - inferStats(instType, opcode, output, inputs); - } - output.allocatedMemory = output.isScalar()? 1 : OptimizerUtils.estimateSizeExactSparsity(output.characteristics); - } + if (inst instanceof MultiReturnBuiltinCPInstruction) { + boolean inferred = false; + for (VarStats current : inputs) { + if (!inferred && current.getCells() < 0) { + inferStats(instType, opcode, output, inputs); + inferred = true; + } + if (current.getCells() < 0) { + throw new RuntimeException("Operation of type MultiReturnBuiltin with opcode '" + opcode + "' has incomplete formula for inferring dimensions"); + } + current.allocatedMemory = OptimizerUtils.estimateSizeExactSparsity(current.characteristics); + } + return; + } else if (output.getCells() < 0) { + inferStats(instType, opcode, output, inputs); + } + output.allocatedMemory = output.isScalar()? 
1 : OptimizerUtils.estimateSizeExactSparsity(output.characteristics); + } - public static void inferStats(CPType instType, String opcode, VarStats output, VarStats...inputs) { - switch (instType) { - case Unary: - case Builtin: - copyMissingDim(output, inputs[0]); - break; - case AggregateUnary: - if (opcode.startsWith("uar")) { - copyMissingDim(output, inputs[0].getM(), 1); - } else if (opcode.startsWith("uac")) { - copyMissingDim(output, 1, inputs[0].getN()); - } else { - copyMissingDim(output, 1, 1); - } - break; - case MatrixIndexing: - if (opcode.equals("rightIndex")) { - long rowLower = (inputs[2].varName.matches("\\d+") ? Long.parseLong(inputs[2].varName) : -1); - long rowUpper = (inputs[3].varName.matches("\\d+") ? Long.parseLong(inputs[3].varName) : -1); - long colLower = (inputs[4].varName.matches("\\d+") ? Long.parseLong(inputs[4].varName) : -1); - long colUpper = (inputs[5].varName.matches("\\d+") ? Long.parseLong(inputs[5].varName) : -1); + public static void inferStats(CPType instType, String opcode, VarStats output, VarStats...inputs) { + switch (instType) { + case Unary: + case Builtin: + copyMissingDim(output, inputs[0]); + break; + case AggregateUnary: + if (opcode.startsWith("uar")) { + copyMissingDim(output, inputs[0].getM(), 1); + } else if (opcode.startsWith("uac")) { + copyMissingDim(output, 1, inputs[0].getN()); + } else { + copyMissingDim(output, 1, 1); + } + break; + case MatrixIndexing: + if (opcode.equals("rightIndex")) { + long rowLower = (inputs[2].varName.matches("\\d+") ? Long.parseLong(inputs[2].varName) : -1); + long rowUpper = (inputs[3].varName.matches("\\d+") ? Long.parseLong(inputs[3].varName) : -1); + long colLower = (inputs[4].varName.matches("\\d+") ? Long.parseLong(inputs[4].varName) : -1); + long colUpper = (inputs[5].varName.matches("\\d+") ? Long.parseLong(inputs[5].varName) : -1); - long rowRange; - { - if (rowLower > 0 && rowUpper > 0) rowRange = rowUpper - rowLower + 1; - else if (inputs[2].varName.equals(inputs[3].varName)) rowRange = 1; - else - rowRange = inputs[0].getM() > 0 ? inputs[0].getM() : DEFAULT_INFERRED_DIM; - } - long colRange; - { - if (colLower > 0 && colUpper > 0) colRange = colUpper - colLower + 1; - else if (inputs[4].varName.equals(inputs[5].varName)) colRange = 1; - else - colRange = inputs[0].getM() > 0 ? inputs[0].getN() : DEFAULT_INFERRED_DIM; - } - copyMissingDim(output, rowRange, colRange); - } else { // leftIndex - copyMissingDim(output, inputs[0]); - } - break; - case Reorg: - switch (opcode) { - case "r'": - copyMissingDim(output, inputs[0].getN(), inputs[0].getM()); - break; - case "rev": - copyMissingDim(output, inputs[0]); - break; - case "rdiag": - if (inputs[0].getN() == 1) // diagV2M - copyMissingDim(output, inputs[0].getM(), inputs[0].getM()); - else // diagM2V - copyMissingDim(output, inputs[0].getM(), 1); - break; - case "rsort": - boolean ixRet = Boolean.parseBoolean(inputs[1].varName); - if (ixRet) - copyMissingDim(output, inputs[0].getM(), 1); - else - copyMissingDim(output, inputs[0]); - break; - } - break; - case Binary: - // handle case of matrix-scalar op. with the matrix being the second operand - VarStats origin = inputs[0].isScalar()? 
inputs[1] : inputs[0]; - copyMissingDim(output, origin); - break; - case AggregateBinary: - boolean transposeLeft = false; - boolean transposeRight = false; - if (inputs.length == 4) { - transposeLeft = inputs[2] != null && Boolean.parseBoolean(inputs[2].varName); - transposeRight = inputs[3] != null && Boolean.parseBoolean(inputs[3].varName); - } - if (transposeLeft && transposeRight) - copyMissingDim(output, inputs[0].getM(), inputs[1].getM()); - else if (transposeLeft) - copyMissingDim(output, inputs[0].getM(), inputs[1].getN()); - else if (transposeRight) - copyMissingDim(output, inputs[0].getN(), inputs[1].getN()); - else - copyMissingDim(output, inputs[0].getN(), inputs[1].getM()); - break; - case ParameterizedBuiltin: - if (opcode.equals("rmempty") || opcode.equals("replace")) { - copyMissingDim(output, inputs[0]); - } else if (opcode.equals("uppertri") || opcode.equals("lowertri")) { - copyMissingDim(output, inputs[0].getM(), inputs[0].getM()); - } - break; - case Rand: - // inferring missing output dimensions is handled exceptionally here - if (output.getCells() < 0) { - long nrows = (inputs[0].varName.matches("\\d+") ? Long.parseLong(inputs[0].varName) : -1); - long ncols = (inputs[1].varName.matches("\\d+") ? Long.parseLong(inputs[1].varName) : -1); - copyMissingDim(output, nrows, ncols); - } - break; - case Ctable: - long m = (inputs[2].varName.matches("\\d+") ? Long.parseLong(inputs[2].varName) : -1); - long n = (inputs[3].varName.matches("\\d+") ? Long.parseLong(inputs[3].varName) : -1); - if (inputs[1].isScalar()) {// Histogram - if (m < 0) m = inputs[0].getM(); - if (n < 0) n = 1; - copyMissingDim(output, m, n); - } else { // transform (including "ctableexpand") - if (m < 0) m = inputs[0].getM(); - if (n < 0) n = inputs[1].getCells(); // NOTE: very generous assumption, it could be revised; - copyMissingDim(output, m, n); - } - break; - case MultiReturnBuiltin: - // special case: output and inputs stats arguments are swapped: always single input with multiple outputs - VarStats FirstStats = inputs[0]; - VarStats SecondStats = inputs[1]; - switch (opcode) { - case "qr": - copyMissingDim(FirstStats, output.getM(), output.getM()); // Q - copyMissingDim(SecondStats, output.getM(), output.getN()); // R - break; - case "lu": - copyMissingDim(FirstStats, output.getN(), output.getN()); // L - copyMissingDim(SecondStats, output.getN(), output.getN()); // U - break; - case "eigen": - copyMissingDim(FirstStats, output.getN(), 1); // values - copyMissingDim(SecondStats, output.getN(), output.getN()); // vectors - break; - // not all opcodes supported yet - } - break; - default: - throw new RuntimeException("Operation of type "+instType+" with opcode '"+opcode+"' has no formula for inferring dimensions"); - } - if (output.getCells() < 0) { - throw new RuntimeException("Operation of type "+instType+" with opcode '"+opcode+"' has incomplete formula for inferring dimensions"); - } - if (output.getNNZ() < 0) { - output.characteristics.setNonZeros(output.getCells()); - } - } + long rowRange; + { + if (rowLower > 0 && rowUpper > 0) rowRange = rowUpper - rowLower + 1; + else if (inputs[2].varName.equals(inputs[3].varName)) rowRange = 1; + else + rowRange = inputs[0].getM() > 0 ? inputs[0].getM() : DEFAULT_INFERRED_DIM; + } + long colRange; + { + if (colLower > 0 && colUpper > 0) colRange = colUpper - colLower + 1; + else if (inputs[4].varName.equals(inputs[5].varName)) colRange = 1; + else + colRange = inputs[0].getM() > 0 ? 
inputs[0].getN() : DEFAULT_INFERRED_DIM; + } + copyMissingDim(output, rowRange, colRange); + } else { // leftIndex + copyMissingDim(output, inputs[0]); + } + break; + case Reorg: + switch (opcode) { + case "r'": + copyMissingDim(output, inputs[0].getN(), inputs[0].getM()); + break; + case "rev": + copyMissingDim(output, inputs[0]); + break; + case "rdiag": + if (inputs[0].getN() == 1) // diagV2M + copyMissingDim(output, inputs[0].getM(), inputs[0].getM()); + else // diagM2V + copyMissingDim(output, inputs[0].getM(), 1); + break; + case "rsort": + boolean ixRet = Boolean.parseBoolean(inputs[1].varName); + if (ixRet) + copyMissingDim(output, inputs[0].getM(), 1); + else + copyMissingDim(output, inputs[0]); + break; + } + break; + case Binary: + // handle case of matrix-scalar op. with the matrix being the second operand + VarStats origin = inputs[0].isScalar()? inputs[1] : inputs[0]; + copyMissingDim(output, origin); + break; + case AggregateBinary: + boolean transposeLeft = false; + boolean transposeRight = false; + if (inputs.length == 4) { + transposeLeft = inputs[2] != null && Boolean.parseBoolean(inputs[2].varName); + transposeRight = inputs[3] != null && Boolean.parseBoolean(inputs[3].varName); + } + if (transposeLeft && transposeRight) + copyMissingDim(output, inputs[0].getM(), inputs[1].getM()); + else if (transposeLeft) + copyMissingDim(output, inputs[0].getM(), inputs[1].getN()); + else if (transposeRight) + copyMissingDim(output, inputs[0].getN(), inputs[1].getN()); + else + copyMissingDim(output, inputs[0].getN(), inputs[1].getM()); + break; + case ParameterizedBuiltin: + if (opcode.equals("rmempty") || opcode.equals("replace")) { + copyMissingDim(output, inputs[0]); + } else if (opcode.equals("uppertri") || opcode.equals("lowertri")) { + copyMissingDim(output, inputs[0].getM(), inputs[0].getM()); + } + break; + case Rand: + // inferring missing output dimensions is handled exceptionally here + if (output.getCells() < 0) { + long nrows = (inputs[0].varName.matches("\\d+") ? Long.parseLong(inputs[0].varName) : -1); + long ncols = (inputs[1].varName.matches("\\d+") ? Long.parseLong(inputs[1].varName) : -1); + copyMissingDim(output, nrows, ncols); + } + break; + case Ctable: + long m = (inputs[2].varName.matches("\\d+") ? Long.parseLong(inputs[2].varName) : -1); + long n = (inputs[3].varName.matches("\\d+") ? 
Long.parseLong(inputs[3].varName) : -1); + if (inputs[1].isScalar()) {// Histogram + if (m < 0) m = inputs[0].getM(); + if (n < 0) n = 1; + copyMissingDim(output, m, n); + } else { // transform (including "ctableexpand") + if (m < 0) m = inputs[0].getM(); + if (n < 0) n = inputs[1].getCells(); // NOTE: very generous assumption, it could be revised; + copyMissingDim(output, m, n); + } + break; + case MultiReturnBuiltin: + // special case: output and inputs stats arguments are swapped: always single input with multiple outputs + VarStats FirstStats = inputs[0]; + VarStats SecondStats = inputs[1]; + switch (opcode) { + case "qr": + copyMissingDim(FirstStats, output.getM(), output.getM()); // Q + copyMissingDim(SecondStats, output.getM(), output.getN()); // R + break; + case "lu": + copyMissingDim(FirstStats, output.getN(), output.getN()); // L + copyMissingDim(SecondStats, output.getN(), output.getN()); // U + break; + case "eigen": + copyMissingDim(FirstStats, output.getN(), 1); // values + copyMissingDim(SecondStats, output.getN(), output.getN()); // vectors + break; + // not all opcodes supported yet + } + break; + default: + throw new RuntimeException("Operation of type "+instType+" with opcode '"+opcode+"' has no formula for inferring dimensions"); + } + if (output.getCells() < 0) { + throw new RuntimeException("Operation of type "+instType+" with opcode '"+opcode+"' has incomplete formula for inferring dimensions"); + } + if (output.getNNZ() < 0) { + output.characteristics.setNonZeros(output.getCells()); + } + } - private static void copyMissingDim(VarStats target, long originRows, long originCols) { - if (target.getM() < 0) - target.characteristics.setRows(originRows); - if (target.getN() < 0) - target.characteristics.setCols(originCols); - } + private static void copyMissingDim(VarStats target, long originRows, long originCols) { + if (target.getM() < 0) + target.characteristics.setRows(originRows); + if (target.getN() < 0) + target.characteristics.setCols(originCols); + } - private static void copyMissingDim(VarStats target, VarStats origin) { - if (target.getM() < 0) - target.characteristics.setRows(origin.getM()); - if (target.getN() < 0) - target.characteristics.setCols(origin.getN()); - } + private static void copyMissingDim(VarStats target, VarStats origin) { + if (target.getM() < 0) + target.characteristics.setRows(origin.getM()); + if (target.getN() < 0) + target.characteristics.setCols(origin.getN()); + } - public static double getCPUTime(long nflop, IOCostUtils.IOMetrics driverMetrics, VarStats output, VarStats...inputs) { - double memScanTime = 0; - for (VarStats input: inputs) { - if (input == null) continue; - memScanTime += IOCostUtils.getMemReadTime(input, driverMetrics); - } - double cpuComputationTime = (double) nflop / driverMetrics.cpuFLOPS; - double memWriteTime = output != null? IOCostUtils.getMemWriteTime(output, driverMetrics) : 0; - return Math.max(memScanTime, cpuComputationTime) + memWriteTime; - } + public static double getCPUTime(long nflop, IOCostUtils.IOMetrics driverMetrics, VarStats output, VarStats...inputs) { + double memScanTime = 0; + for (VarStats input: inputs) { + if (input == null) continue; + memScanTime += IOCostUtils.getMemReadTime(input, driverMetrics); + } + double cpuComputationTime = (double) nflop / driverMetrics.cpuFLOPS; + double memWriteTime = output != null? 
IOCostUtils.getMemWriteTime(output, driverMetrics) : 0; + return Math.max(memScanTime, cpuComputationTime) + memWriteTime; + } - /** - * - * @param instructionType instruction type - * @param opcode instruction opcode, potentially with suffix to mark an extra op. characteristic - * @param output output's variable statistics, null is not needed for the estimation - * @param inputs any inputs' variable statistics, no object passed is not needed for estimation - * @return estimated number of floating point operations - */ - public static long getInstNFLOP( - CPType instructionType, - String opcode, - VarStats output, - VarStats...inputs - ) { - opcode = opcode.toLowerCase(); // enforce lowercase for convince - long m; - double costs = 0; - switch (instructionType) { - // types corresponding to UnaryCPInstruction - case Unary: - case Builtin: // log and log_nz only - if (output == null || inputs.length < 1) - throw new RuntimeException("Not all required arguments for Unary/Builtin operations are passed initialized"); - double sparsity = inputs[0].getSparsity(); - switch (opcode) { - case "!": - case "isna": - case "isnan": - case "isinf": - case "ceil": - case "floor": - costs = 1; - break; - case "abs": - case "round": - case "sign": - costs = 1 * sparsity; - break; - case "sprop": - case "sqrt": - costs = 2 * sparsity; - break; - case "exp": - costs = 18 * sparsity; - break; - case "sigmoid": - costs = 21 * sparsity; - break; - case "log": - costs = 32; - break; - case "log_nz": - case "plogp": - costs = 32 * sparsity; - break; - case "print": - case "assert": - costs = 1; - break; - case "sin": - costs = 18 * sparsity; - break; - case "cos": - costs = 22 * inputs[0].getSparsity(); - break; - case "tan": - costs = 42 * inputs[0].getSparsity(); - break; - case "asin": - case "sinh": - costs = 93; - break; - case "acos": - case "cosh": - costs = 103; - break; - case "atan": - case "tanh": - costs = 40; - break; - case "ucumk+": - case "ucummin": - case "ucummax": - case "ucum*": - costs = 1 * sparsity; - break; - case "ucumk+*": - costs = 2 * sparsity; - break; - case "stop": - costs = 0; - break; - case "typeof": - costs = 1; - break; - case "inverse": - costs = (4.0 / 3.0) * output.getCellsWithSparsity() * output.getCellsWithSparsity(); - break; - case "cholesky": - costs = (1.0 / 3.0) * output.getCellsWithSparsity() * output.getCellsWithSparsity(); - break; - case "detectschema": - case "colnames": - throw new RuntimeException("Specific Frame operation with opcode '" + opcode + "' is not supported yet"); - default: - // at the point of implementation no further supported operations - throw new DMLRuntimeException("Unary operation with opcode '" + opcode + "' is not supported by SystemDS"); - } - return (long) (costs * output.getCells()); - case AggregateUnary: - if (output == null || inputs.length < 1) - throw new RuntimeException("Not all required arguments for AggregateUnary operations are passed initialized"); - switch (opcode) { - case "nrow": - case "ncol": - case "length": - case "exists": - case "lineage": - return DEFAULT_NFLOP_NOOP; - case "uak+": - case "uark+": - case "uack+": - costs = 4; - break; - case "uasqk+": - case "uarsqk+": - case "uacsqk+": - costs = 5; - break; - case "uamean": - case "uarmean": - case "uacmean": - costs = 7; - break; - case "uavar": - case "uarvar": - case "uacvar": - costs = 14; - break; - case "uamax": - case "uarmax": - case "uarimax": - case "uacmax": - case "uamin": - case "uarmin": - case "uarimin": - case "uacmin": - costs = 1; - break; - case 
"ua+": - case "uar+": - case "uac+": - case "ua*": - case "uar*": - case "uac*": - costs = 1 * output.getSparsity(); - break; - // count distinct operations - case "uacd": - case "uacdr": - case "uacdc": - case "unique": - case "uniquer": - case "uniquec": - costs = 1 * output.getSparsity(); - break; - case "uacdap": - case "uacdapr": - case "uacdapc": - costs = 0.5 * output.getSparsity(); // do not iterate through all the cells - break; - // aggregation over the diagonal of a square matrix - case "uatrace": - case "uaktrace": - return inputs[0].getM(); - default: - // at the point of implementation no further supported operations - throw new DMLRuntimeException("AggregateUnary operation with opcode '" + opcode + "' is not supported by SystemDS"); - } - // scale - if (opcode.startsWith("uar")) { - costs *= inputs[0].getM(); - } else if (opcode.startsWith("uac")) { - costs *= inputs[0].getN(); - } else { - costs *= inputs[0].getCells(); - } - return (long) (costs * output.getCells()); - case MMTSJ: - if (inputs.length < 1) - throw new RuntimeException("Not all required arguments for MMTSJ operations are passed initialized"); - // reduce by factor of 4: matrix multiplication better than average FLOP count - // + multiply only upper triangular - if (opcode.equals("tsmm_left")) { - costs = inputs[0].getN() * (inputs[0].getSparsity() / 2); - } else { // tsmm/tsmm_right - costs = inputs[0].getM() * (inputs[0].getSparsity() / 2); - } - return (long) (costs * inputs[0].getCellsWithSparsity()); - case Reorg: - case Reshape: - if (output == null) - throw new RuntimeException("Not all required arguments for Reorg/Reshape operations are passed initialized"); - if (opcode.equals("rsort")) - return (long) (output.getCellsWithSparsity() * (Math.log(output.getM()) / Math.log(2))); // merge sort columns (n*m*log2(m)) - return output.getCellsWithSparsity(); - case MatrixIndexing: - if (output == null) - throw new RuntimeException("Not all required arguments for Indexing operations are passed initialized"); - return output.getCellsWithSparsity(); - case MMChain: - if (inputs.length < 1) - throw new RuntimeException("Not all required arguments for MMChain operations are passed initialized"); - // reduction by factor 2 because matrix mult better than average flop count - // (mmchain essentially two matrix-vector muliplications) - return (2 + 2) * inputs[0].getCellsWithSparsity() / 2; - case QSort: - if (inputs.length < 1) - throw new RuntimeException("Not all required arguments for QSort operations are passed initialized"); - // mergesort since comparator used - m = inputs[0].getM(); - if (opcode.equals("qsort")) - costs = m + m; - else // == "qsort_wts" (with weights) - costs = m * inputs[0].getSparsity(); - return (long) (costs + m * (int) (Math.log(m) / Math.log(2)) + m); - case CentralMoment: - if (inputs.length < 1) - throw new RuntimeException("Not all required arguments for CentralMoment operations are passed initialized"); - switch (opcode) { - case "cm_sum": - throw new RuntimeException("Undefined behaviour for CentralMoment operation of type sum"); - case "cm_min": - case "cm_max": - case "cm_count": - costs = 2; - break; - case "cm_mean": - costs = 9; - break; - case "cm_variance": - case "cm_cm2": - costs = 17; - break; - case "cm_cm3": - costs = 32; - break; - case "cm_cm4": - costs = 52; - break; - case "cm_invalid": - // type INVALID used when unknown dimensions - throw new RuntimeException("CentralMoment operation of type INVALID is not supported"); - default: - // at the point of 
implementation no further supported operations - throw new DMLRuntimeException("CentralMoment operation with type (_) '" + opcode + "' is not supported by SystemDS"); - } - return (long) costs * inputs[0].getCellsWithSparsity(); - case UaggOuterChain: - case Dnn: - throw new RuntimeException("CP operation type'" + instructionType + "' is not supported yet"); - // types corresponding to BinaryCPInstruction - case Binary: - if (opcode.equals("+") || opcode.equals("-")) { - if (inputs.length < 2) - throw new RuntimeException("Not all required arguments for Binary operations +/- are passed initialized"); - return inputs[0].getCellsWithSparsity() + inputs[1].getCellsWithSparsity(); - } else if (opcode.equals("solve")) { - if (inputs.length < 1) - throw new RuntimeException("Not all required arguments for Binary operation 'solve' are passed initialized"); - return inputs[0].getCells() * inputs[0].getN(); - } - if (output == null) - throw new RuntimeException("Not all required arguments for Binary operations are passed initialized"); - switch (opcode) { - case "*": - case "^2": - case "*2": - case "max": - case "min": - case "-nz": - case "==": - case "!=": - case "<": - case ">": - case "<=": - case ">=": - case "&&": - case "||": - case "xor": - case "bitwand": - case "bitwor": - case "bitwxor": - case "bitwshiftl": - case "bitwshiftr": - costs = 1; - break; - case "%/%": - costs = 6; - break; - case "%%": - costs = 8; - break; - case "/": - costs = 22; - break; - case "log": - case "log_nz": - costs = 32; - break; - case "^": - costs = 16; - break; - case "1-*": - costs = 2; - break; - case "dropinvalidtype": - case "dropinvalidlength": - case "freplicate": - case "valueswap": - case "applyschema": - throw new RuntimeException("Specific Frame operation with opcode '" + opcode + "' is not supported yet"); - default: - // at the point of implementation no further supported operations - throw new DMLRuntimeException("Binary operation with opcode '" + opcode + "' is not supported by SystemDS"); - } - return (long) (costs * output.getCells()); - case AggregateBinary: - if (output == null || inputs.length < 2) - throw new RuntimeException("Not all required arguments for AggregateBinary operations are passed initialized"); - // costs represents the cost for matrix transpose - if (opcode.contains("_tl")) costs = inputs[0].getCellsWithSparsity(); - if (opcode.contains("_tr")) costs = inputs[1].getCellsWithSparsity(); - // else ba+*/pmm (or any of cpmm/rmm/mapmm from the Spark instructions) - // reduce by factor of 2: matrix multiplication better than average FLOP count: 2*m*n*p=m*n*p - return (long) (inputs[0].getN() * inputs[0].getSparsity()) * output.getCells() + (long) costs; - case Append: - if (inputs.length < 2) - throw new RuntimeException("Not all required arguments for Append operation is passed initialized"); - return inputs[0].getCellsWithSparsity() * inputs[1].getCellsWithSparsity(); - case Covariance: - if (inputs.length < 1) - throw new RuntimeException("Not all required arguments for Covariance operation is passed initialized"); - return (long) (23 * inputs[0].getM() * inputs[0].getSparsity()); - case QPick: - switch (opcode) { - case "qpick_iqm": - m = inputs[0].getM(); - return (long) (2 * m + //sum of weights - 5 * 0.25d * m + //scan to lower quantile - 8 * 0.5 * m); //scan from lower to upper quantile - case "qpick_median": - case "qpick_valuepick": - case "qpick_rangepick": - throw new RuntimeException("QuantilePickCPInstruction of operation type different from IQM is not supported 
yet"); - default: - throw new DMLRuntimeException("QPick operation with opcode '" + opcode + "' is not supported by SystemDS"); - } - // types corresponding to others CPInstruction(s) - case Ternary: - if (output == null) - throw new RuntimeException("Not all required arguments for Ternary operation is passed initialized"); - switch (opcode) { - case "+*": - case "-*": - case "ifelse": - return 2 * output.getCells(); - case "_map": - throw new RuntimeException("Specific Frame operation with opcode '" + opcode + "' is not supported yet"); - default: - throw new DMLRuntimeException("Ternary operation with opcode '" + opcode + "' is not supported by SystemDS"); - } - case AggregateTernary: - if (inputs.length < 1) - throw new RuntimeException("Not all required arguments for AggregateTernary operation is passed initialized"); - if (opcode.equals("tak+*") || opcode.equals("tack+*")) - return 6 * inputs[0].getCellsWithSparsity(); - throw new DMLRuntimeException("AggregateTernary operation with opcode '" + opcode + "' is not supported by SystemDS"); - case Quaternary: - //TODO pattern specific and all inputs required - if (inputs.length < 1) - throw new RuntimeException("Not all required arguments for Quaternary operation is passed initialized"); - if (opcode.equals("wsloss") || opcode.equals("wdivmm") || opcode.equals("wcemm")) { - // 4 matrices used - return 4 * inputs[0].getCells(); - } else if (opcode.equals("wsigmoid") || opcode.equals("wumm")) { - // 3 matrices used - return 3 * inputs[0].getCells(); - } - throw new DMLRuntimeException("Quaternary operation with opcode '" + opcode + "' is not supported by SystemDS"); - case BuiltinNary: - if (output == null) - throw new RuntimeException("Not all required arguments for BuiltinNary operation is passed initialized"); - switch (opcode) { - case "cbind": - case "rbind": - return output.getCellsWithSparsity(); - case "nmin": - case "nmax": - case "n+": - return inputs.length * output.getCellsWithSparsity(); - case "printf": - case "list": - return output.getN(); - case "eval": - throw new RuntimeException("EvalNaryCPInstruction is not supported yet"); - default: - throw new DMLRuntimeException("BuiltinNary operation with opcode '" + opcode + "' is not supported by SystemDS"); - } - case Ctable: - if (output == null) - throw new RuntimeException("Not all required arguments for Ctable operation is passed initialized"); - if (opcode.startsWith("ctable")) { - // potential high inaccuracy due to unknown output column size - // and inferring bound on number of elements what could lead to high underestimation - return 3 * output.getCellsWithSparsity(); - } - throw new DMLRuntimeException("Ctable operation with opcode '" + opcode + "' is not supported by SystemDS"); - case PMMJ: - // currently this would never be reached since the pmm instruction uses AggregateBinary op. 
type - if (output == null || inputs.length < 1) - throw new RuntimeException("Not all required arguments for PMMJ operation is passed initialized"); - if (opcode.equals("pmm")) { - return (long) (inputs[0].getN() * inputs[0].getSparsity()) * output.getCells(); - } - throw new DMLRuntimeException("PMMJ operation with opcode '" + opcode + "' is not supported by SystemDS"); - case ParameterizedBuiltin: - // no argument validation here since the logic is not fully defined for this operation - m = inputs[0].getM(); - switch (opcode) { - case "contains": - case "replace": - case "tostring": - return inputs[0].getCells(); - case "nvlist": - case "cdf": - case "invcdf": - case "lowertri": - case "uppertri": - case "rexpand": - return output.getCells(); - case "rmempty_rows": - return (long) (inputs[0].getM() * Math.ceil(1.0d / inputs[0].getSparsity()) / 2) - + output.getCells(); - case "rmempty_cols": - return (long) (inputs[0].getN() * Math.ceil(1.0d / inputs[0].getSparsity()) / 2) - + output.getCells(); - // opcode: "groupedagg" - case "groupedagg_count": - case "groupedagg_min": - case "groupedagg_max": - return 2 * m + m; - case "groupedagg_sum": - return 2 * m + 4 * m; - case "groupedagg_mean": - return 2 * m + 8 * m; - case "groupedagg_cm2": - return 2 * m + 16 * m; - case "groupedagg_cm3": - return 2 * m + 31 * m; - case "groupedagg_cm4": - return 2 * m + 51 * m; - case "groupedagg_variance": - return 2 * m + 16 * m; - case "groupedagg_invalid": - // type INVALID used when unknown dimensions - throw new RuntimeException("ParameterizedBuiltin operation with opcode 'groupedagg' of type INVALID is not supported"); - case "tokenize": - case "transformapply": - case "transformdecode": - case "transformcolmap": - case "transformmeta": - case "autodiff": - case "paramserv": - throw new RuntimeException("ParameterizedBuiltin operation with opcode '" + opcode + "' is not supported yet"); - default: - throw new DMLRuntimeException("ParameterizedBuiltin operation with opcode '" + opcode + "' is not supported by SystemDS"); - } - case MultiReturnBuiltin: - if (inputs.length < 1) - throw new RuntimeException("Not all required arguments for MultiReturnBuiltin operation is passed initialized"); - switch (opcode) { - case "qr": - costs = 2; - break; - case "lu": - costs = 16; - break; - case "eigen": - case "svd": - costs = 32; - break; - case "fft": - case "fft_linearized": - throw new RuntimeException("MultiReturnBuiltin operation with opcode '" + opcode + "' is not supported yet"); - default: - throw new DMLRuntimeException(" MultiReturnBuiltin operation with opcode '" + opcode + "' is not supported by SystemDS"); - } - return (long) (costs * inputs[0].getCells() * inputs[0].getN()); - case Prefetch: - case EvictLineageCache: - case Broadcast: - case Local: - case FCall: - case NoOp: - // not directly related to computation - return 0; - case Variable: - case Rand: - case StringInit: - throw new RuntimeException(instructionType + " instructions are not handled by this method"); - case MultiReturnParameterizedBuiltin: // opcodes: transformencode - case MultiReturnComplexMatrixBuiltin: // opcodes: ifft, ifft_linearized, stft, rcm - case Compression: // opcode: compress - case DeCompression: // opcode: decompress - throw new RuntimeException("CP operation type'" + instructionType + "' is not supported yet"); - case TrigRemote: - case Partition: - case SpoofFused: - case Sql: - throw new RuntimeException("CP operation type'" + instructionType + "' is not planned for support"); - default: - // no further 
supported CP types - throw new DMLRuntimeException("CP operation type'" + instructionType + "' is not supported by SystemDS"); - } - } + /** + * + * @param instructionType instruction type + * @param opcode instruction opcode, potentially with suffix to mark an extra op. characteristic + * @param output output's variable statistics, null is not needed for the estimation + * @param inputs any inputs' variable statistics, no object passed is not needed for estimation + * @return estimated number of floating point operations + */ + public static long getInstNFLOP( + CPType instructionType, + String opcode, + VarStats output, + VarStats...inputs + ) { + opcode = opcode.toLowerCase(); // enforce lowercase for convince + long m; + double costs = 0; + switch (instructionType) { + // types corresponding to UnaryCPInstruction + case Unary: + case Builtin: // log and log_nz only + if (output == null || inputs.length < 1) + throw new RuntimeException("Not all required arguments for Unary/Builtin operations are passed initialized"); + double sparsity = inputs[0].getSparsity(); + switch (opcode) { + case "!": + case "isna": + case "isnan": + case "isinf": + case "ceil": + case "floor": + costs = 1; + break; + case "abs": + case "round": + case "sign": + costs = 1 * sparsity; + break; + case "sprop": + case "sqrt": + costs = 2 * sparsity; + break; + case "exp": + costs = 18 * sparsity; + break; + case "sigmoid": + costs = 21 * sparsity; + break; + case "log": + costs = 32; + break; + case "log_nz": + case "plogp": + costs = 32 * sparsity; + break; + case "print": + case "assert": + costs = 1; + break; + case "sin": + costs = 18 * sparsity; + break; + case "cos": + costs = 22 * inputs[0].getSparsity(); + break; + case "tan": + costs = 42 * inputs[0].getSparsity(); + break; + case "asin": + case "sinh": + costs = 93; + break; + case "acos": + case "cosh": + costs = 103; + break; + case "atan": + case "tanh": + costs = 40; + break; + case "ucumk+": + case "ucummin": + case "ucummax": + case "ucum*": + costs = 1 * sparsity; + break; + case "ucumk+*": + costs = 2 * sparsity; + break; + case "stop": + costs = 0; + break; + case "typeof": + costs = 1; + break; + case "inverse": + costs = (4.0 / 3.0) * output.getCellsWithSparsity() * output.getCellsWithSparsity(); + break; + case "cholesky": + costs = (1.0 / 3.0) * output.getCellsWithSparsity() * output.getCellsWithSparsity(); + break; + case "detectschema": + case "colnames": + throw new RuntimeException("Specific Frame operation with opcode '" + opcode + "' is not supported yet"); + default: + // at the point of implementation no further supported operations + throw new DMLRuntimeException("Unary operation with opcode '" + opcode + "' is not supported by SystemDS"); + } + return (long) (costs * output.getCells()); + case AggregateUnary: + if (output == null || inputs.length < 1) + throw new RuntimeException("Not all required arguments for AggregateUnary operations are passed initialized"); + switch (opcode) { + case "nrow": + case "ncol": + case "length": + case "exists": + case "lineage": + return DEFAULT_NFLOP_NOOP; + case "uak+": + case "uark+": + case "uack+": + costs = 4; + break; + case "uasqk+": + case "uarsqk+": + case "uacsqk+": + costs = 5; + break; + case "uamean": + case "uarmean": + case "uacmean": + costs = 7; + break; + case "uavar": + case "uarvar": + case "uacvar": + costs = 14; + break; + case "uamax": + case "uarmax": + case "uarimax": + case "uacmax": + case "uamin": + case "uarmin": + case "uarimin": + case "uacmin": + costs = 1; + break; 
+ case "ua+": + case "uar+": + case "uac+": + case "ua*": + case "uar*": + case "uac*": + costs = 1 * output.getSparsity(); + break; + // count distinct operations + case "uacd": + case "uacdr": + case "uacdc": + case "unique": + case "uniquer": + case "uniquec": + costs = 1 * output.getSparsity(); + break; + case "uacdap": + case "uacdapr": + case "uacdapc": + costs = 0.5 * output.getSparsity(); // do not iterate through all the cells + break; + // aggregation over the diagonal of a square matrix + case "uatrace": + case "uaktrace": + return inputs[0].getM(); + default: + // at the point of implementation no further supported operations + throw new DMLRuntimeException("AggregateUnary operation with opcode '" + opcode + "' is not supported by SystemDS"); + } + // scale + if (opcode.startsWith("uar")) { + costs *= inputs[0].getM(); + } else if (opcode.startsWith("uac")) { + costs *= inputs[0].getN(); + } else { + costs *= inputs[0].getCells(); + } + return (long) (costs * output.getCells()); + case MMTSJ: + if (inputs.length < 1) + throw new RuntimeException("Not all required arguments for MMTSJ operations are passed initialized"); + // reduce by factor of 4: matrix multiplication better than average FLOP count + // + multiply only upper triangular + if (opcode.equals("tsmm_left")) { + costs = inputs[0].getN() * (inputs[0].getSparsity() / 2); + } else { // tsmm/tsmm_right + costs = inputs[0].getM() * (inputs[0].getSparsity() / 2); + } + return (long) (costs * inputs[0].getCellsWithSparsity()); + case Reorg: + case Reshape: + if (output == null) + throw new RuntimeException("Not all required arguments for Reorg/Reshape operations are passed initialized"); + if (opcode.equals("rsort")) + return (long) (output.getCellsWithSparsity() * (Math.log(output.getM()) / Math.log(2))); // merge sort columns (n*m*log2(m)) + return output.getCellsWithSparsity(); + case MatrixIndexing: + if (output == null) + throw new RuntimeException("Not all required arguments for Indexing operations are passed initialized"); + return output.getCellsWithSparsity(); + case MMChain: + if (inputs.length < 1) + throw new RuntimeException("Not all required arguments for MMChain operations are passed initialized"); + // reduction by factor 2 because matrix mult better than average flop count + // (mmchain essentially two matrix-vector muliplications) + return (2 + 2) * inputs[0].getCellsWithSparsity() / 2; + case QSort: + if (inputs.length < 1) + throw new RuntimeException("Not all required arguments for QSort operations are passed initialized"); + // mergesort since comparator used + m = inputs[0].getM(); + if (opcode.equals("qsort")) + costs = m + m; + else // == "qsort_wts" (with weights) + costs = m * inputs[0].getSparsity(); + return (long) (costs + m * (int) (Math.log(m) / Math.log(2)) + m); + case CentralMoment: + if (inputs.length < 1) + throw new RuntimeException("Not all required arguments for CentralMoment operations are passed initialized"); + switch (opcode) { + case "cm_sum": + throw new RuntimeException("Undefined behaviour for CentralMoment operation of type sum"); + case "cm_min": + case "cm_max": + case "cm_count": + costs = 2; + break; + case "cm_mean": + costs = 9; + break; + case "cm_variance": + case "cm_cm2": + costs = 17; + break; + case "cm_cm3": + costs = 32; + break; + case "cm_cm4": + costs = 52; + break; + case "cm_invalid": + // type INVALID used when unknown dimensions + throw new RuntimeException("CentralMoment operation of type INVALID is not supported"); + default: + // at the point of 
implementation no further supported operations + throw new DMLRuntimeException("CentralMoment operation with type (_) '" + opcode + "' is not supported by SystemDS"); + } + return (long) costs * inputs[0].getCellsWithSparsity(); + case UaggOuterChain: + case Dnn: + throw new RuntimeException("CP operation type'" + instructionType + "' is not supported yet"); + // types corresponding to BinaryCPInstruction + case Binary: + if (opcode.equals("+") || opcode.equals("-")) { + if (inputs.length < 2) + throw new RuntimeException("Not all required arguments for Binary operations +/- are passed initialized"); + return inputs[0].getCellsWithSparsity() + inputs[1].getCellsWithSparsity(); + } else if (opcode.equals("solve")) { + if (inputs.length < 1) + throw new RuntimeException("Not all required arguments for Binary operation 'solve' are passed initialized"); + return inputs[0].getCells() * inputs[0].getN(); + } + if (output == null) + throw new RuntimeException("Not all required arguments for Binary operations are passed initialized"); + switch (opcode) { + case "*": + case "^2": + case "*2": + case "max": + case "min": + case "-nz": + case "==": + case "!=": + case "<": + case ">": + case "<=": + case ">=": + case "&&": + case "||": + case "xor": + case "bitwand": + case "bitwor": + case "bitwxor": + case "bitwshiftl": + case "bitwshiftr": + costs = 1; + break; + case "%/%": + costs = 6; + break; + case "%%": + costs = 8; + break; + case "/": + costs = 22; + break; + case "log": + case "log_nz": + costs = 32; + break; + case "^": + costs = 16; + break; + case "1-*": + costs = 2; + break; + case "dropinvalidtype": + case "dropinvalidlength": + case "freplicate": + case "valueswap": + case "applyschema": + throw new RuntimeException("Specific Frame operation with opcode '" + opcode + "' is not supported yet"); + default: + // at the point of implementation no further supported operations + throw new DMLRuntimeException("Binary operation with opcode '" + opcode + "' is not supported by SystemDS"); + } + return (long) (costs * output.getCells()); + case AggregateBinary: + if (output == null || inputs.length < 2) + throw new RuntimeException("Not all required arguments for AggregateBinary operations are passed initialized"); + // costs represents the cost for matrix transpose + if (opcode.contains("_tl")) costs = inputs[0].getCellsWithSparsity(); + if (opcode.contains("_tr")) costs = inputs[1].getCellsWithSparsity(); + // else ba+*/pmm (or any of cpmm/rmm/mapmm from the Spark instructions) + // reduce by factor of 2: matrix multiplication better than average FLOP count: 2*m*n*p=m*n*p + return (long) (inputs[0].getN() * inputs[0].getSparsity()) * output.getCells() + (long) costs; + case Append: + if (inputs.length < 2) + throw new RuntimeException("Not all required arguments for Append operation is passed initialized"); + return inputs[0].getCellsWithSparsity() * inputs[1].getCellsWithSparsity(); + case Covariance: + if (inputs.length < 1) + throw new RuntimeException("Not all required arguments for Covariance operation is passed initialized"); + return (long) (23 * inputs[0].getM() * inputs[0].getSparsity()); + case QPick: + switch (opcode) { + case "qpick_iqm": + m = inputs[0].getM(); + return (long) (2 * m + //sum of weights + 5 * 0.25d * m + //scan to lower quantile + 8 * 0.5 * m); //scan from lower to upper quantile + case "qpick_median": + case "qpick_valuepick": + case "qpick_rangepick": + throw new RuntimeException("QuantilePickCPInstruction of operation type different from IQM is not supported 
yet"); + default: + throw new DMLRuntimeException("QPick operation with opcode '" + opcode + "' is not supported by SystemDS"); + } + // types corresponding to others CPInstruction(s) + case Ternary: + if (output == null) + throw new RuntimeException("Not all required arguments for Ternary operation is passed initialized"); + switch (opcode) { + case "+*": + case "-*": + case "ifelse": + return 2 * output.getCells(); + case "_map": + throw new RuntimeException("Specific Frame operation with opcode '" + opcode + "' is not supported yet"); + default: + throw new DMLRuntimeException("Ternary operation with opcode '" + opcode + "' is not supported by SystemDS"); + } + case AggregateTernary: + if (inputs.length < 1) + throw new RuntimeException("Not all required arguments for AggregateTernary operation is passed initialized"); + if (opcode.equals("tak+*") || opcode.equals("tack+*")) + return 6 * inputs[0].getCellsWithSparsity(); + throw new DMLRuntimeException("AggregateTernary operation with opcode '" + opcode + "' is not supported by SystemDS"); + case Quaternary: + //TODO pattern specific and all inputs required + if (inputs.length < 1) + throw new RuntimeException("Not all required arguments for Quaternary operation is passed initialized"); + if (opcode.equals("wsloss") || opcode.equals("wdivmm") || opcode.equals("wcemm")) { + // 4 matrices used + return 4 * inputs[0].getCells(); + } else if (opcode.equals("wsigmoid") || opcode.equals("wumm")) { + // 3 matrices used + return 3 * inputs[0].getCells(); + } + throw new DMLRuntimeException("Quaternary operation with opcode '" + opcode + "' is not supported by SystemDS"); + case BuiltinNary: + if (output == null) + throw new RuntimeException("Not all required arguments for BuiltinNary operation is passed initialized"); + switch (opcode) { + case "cbind": + case "rbind": + return output.getCellsWithSparsity(); + case "nmin": + case "nmax": + case "n+": + return inputs.length * output.getCellsWithSparsity(); + case "printf": + case "list": + return output.getN(); + case "eval": + throw new RuntimeException("EvalNaryCPInstruction is not supported yet"); + default: + throw new DMLRuntimeException("BuiltinNary operation with opcode '" + opcode + "' is not supported by SystemDS"); + } + case Ctable: + if (output == null) + throw new RuntimeException("Not all required arguments for Ctable operation is passed initialized"); + if (opcode.startsWith("ctable")) { + // potential high inaccuracy due to unknown output column size + // and inferring bound on number of elements what could lead to high underestimation + return 3 * output.getCellsWithSparsity(); + } + throw new DMLRuntimeException("Ctable operation with opcode '" + opcode + "' is not supported by SystemDS"); + case PMMJ: + // currently this would never be reached since the pmm instruction uses AggregateBinary op. 
+ if (output == null || inputs.length < 1)
+ throw new RuntimeException("Not all required arguments for PMMJ operation are initialized");
+ if (opcode.equals("pmm")) {
+ return (long) (inputs[0].getN() * inputs[0].getSparsity()) * output.getCells();
+ }
+ throw new DMLRuntimeException("PMMJ operation with opcode '" + opcode + "' is not supported by SystemDS");
+ case ParameterizedBuiltin:
+ // no argument validation here since the logic is not fully defined for this operation
+ m = inputs[0].getM();
+ switch (opcode) {
+ case "contains":
+ case "replace":
+ case "tostring":
+ return inputs[0].getCells();
+ case "nvlist":
+ case "cdf":
+ case "invcdf":
+ case "lowertri":
+ case "uppertri":
+ case "rexpand":
+ return output.getCells();
+ case "rmempty_rows":
+ return (long) (inputs[0].getM() * Math.ceil(1.0d / inputs[0].getSparsity()) / 2)
+ + output.getCells();
+ case "rmempty_cols":
+ return (long) (inputs[0].getN() * Math.ceil(1.0d / inputs[0].getSparsity()) / 2)
+ + output.getCells();
+ // opcode: "groupedagg"
+ case "groupedagg_count":
+ case "groupedagg_min":
+ case "groupedagg_max":
+ return 2 * m + m;
+ case "groupedagg_sum":
+ return 2 * m + 4 * m;
+ case "groupedagg_mean":
+ return 2 * m + 8 * m;
+ case "groupedagg_cm2":
+ return 2 * m + 16 * m;
+ case "groupedagg_cm3":
+ return 2 * m + 31 * m;
+ case "groupedagg_cm4":
+ return 2 * m + 51 * m;
+ case "groupedagg_variance":
+ return 2 * m + 16 * m;
+ case "groupedagg_invalid":
+ // type INVALID is used when dimensions are unknown
+ throw new RuntimeException("ParameterizedBuiltin operation with opcode 'groupedagg' of type INVALID is not supported");
+ case "tokenize":
+ case "transformapply":
+ case "transformdecode":
+ case "transformcolmap":
+ case "transformmeta":
+ case "autodiff":
+ case "paramserv":
+ throw new RuntimeException("ParameterizedBuiltin operation with opcode '" + opcode + "' is not supported yet");
+ default:
+ throw new DMLRuntimeException("ParameterizedBuiltin operation with opcode '" + opcode + "' is not supported by SystemDS");
+ }
+ case MultiReturnBuiltin:
+ if (inputs.length < 1)
+ throw new RuntimeException("Not all required arguments for MultiReturnBuiltin operation are initialized");
+ switch (opcode) {
+ case "qr":
+ costs = 2;
+ break;
+ case "lu":
+ costs = 16;
+ break;
+ case "eigen":
+ case "svd":
+ costs = 32;
+ break;
+ case "fft":
+ case "fft_linearized":
+ throw new RuntimeException("MultiReturnBuiltin operation with opcode '" + opcode + "' is not supported yet");
+ default:
+ throw new DMLRuntimeException("MultiReturnBuiltin operation with opcode '" + opcode + "' is not supported by SystemDS");
+ }
+ return (long) (costs * inputs[0].getCells() * inputs[0].getN());
+ case Prefetch:
+ case EvictLineageCache:
+ case Broadcast:
+ case Local:
+ case FCall:
+ case NoOp:
+ // not directly related to computation
+ return 0;
+ case Variable:
+ case Rand:
+ case StringInit:
+ throw new RuntimeException(instructionType + " instructions are not handled by this method");
+ case MultiReturnParameterizedBuiltin: // opcodes: transformencode
+ case MultiReturnComplexMatrixBuiltin: // opcodes: ifft, ifft_linearized, stft, rcm
+ case Compression: // opcode: compress
+ case DeCompression: // opcode: decompress
+ throw new RuntimeException("CP operation type '" + instructionType + "' is not supported yet");
+ case TrigRemote:
+ case Partition:
+ case SpoofFused:
+ case Sql:
+ throw new RuntimeException("CP operation type '" + instructionType + "' is not planned for support");
+ default:
+ // no further supported CP types
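+ // Reading of the groupedagg constants above (a descriptive sketch, not
+ // normative): each estimate is 2*m for extracting the target and group
+ // columns plus an opcode-specific k*m for the aggregation itself (k = 1
+ // for count/min/max, 4 for sum, 8 for mean, 16 for cm2/variance, up to 51
+ // for cm4). For an assumed m = 10^6 rows, groupedagg_mean costs
+ // 2*10^6 + 8*10^6 = 10^7 FLOP.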
+ throw new DMLRuntimeException("CP operation type '" + instructionType + "' is not supported by SystemDS");
+ }
+ }
 }
diff --git a/src/main/java/org/apache/sysds/resource/cost/SparkCostUtils.java b/src/main/java/org/apache/sysds/resource/cost/SparkCostUtils.java
index 479cf0f24fe..9d9851907a1 100644
--- a/src/main/java/org/apache/sysds/resource/cost/SparkCostUtils.java
+++ b/src/main/java/org/apache/sysds/resource/cost/SparkCostUtils.java
@@ -38,775 +38,775 @@ public class SparkCostUtils {
- public static double getReblockInstTime(String opcode, VarStats input, VarStats output, IOMetrics executorMetrics) {
- // Reblock triggers a new stage
- // old stage: read text file + shuffle the intermediate text rdd
- double readTime = getHadoopReadTime(input, executorMetrics);
- long sizeTextFile = OptimizerUtils.estimateSizeTextOutput(input.getM(), input.getN(), input.getNNZ(), (Types.FileFormat) input.fileInfo[1]);
- RDDStats textRdd = new RDDStats(sizeTextFile, -1);
- double shuffleTime = getSparkShuffleTime(textRdd, executorMetrics, true);
- double timeStage1 = readTime + shuffleTime;
- // new stage: transform partitioned shuffled text object into partitioned binary object
- long nflop = getInstNFLOP(SPType.Reblock, opcode, output);
- double timeStage2 = getCPUTime(nflop, textRdd.numPartitions, executorMetrics, output.rddStats, textRdd);
- return timeStage1 + timeStage2;
- }
-
- public static double getRandInstTime(String opcode, int randType, VarStats output, IOMetrics executorMetrics) {
- if (opcode.equals(SAMPLE_OPCODE)) {
- // sample uses sortByKey() op. and it should be handled differently
- throw new RuntimeException("Spark operation Rand with opcode " + SAMPLE_OPCODE + " is not supported yet");
- }
-
- long nflop;
- if (opcode.equals(RAND_OPCODE) || opcode.equals(FRAME_OPCODE)) {
- if (randType == 0) return 0; // empty matrix
- else if (randType == 1) nflop = 8; // allocate, array fill
- else if (randType == 2) nflop = 32; // full rand
- else throw new RuntimeException("Unknown type of random instruction");
- } else if (opcode.equals(SEQ_OPCODE)) {
- nflop = 1;
- } else {
- throw new DMLRuntimeException("Rand operation with opcode '" + opcode + "' is not supported by SystemDS");
- }
- nflop *= output.getCells();
- // no shuffling required -> only computation time
- return getCPUTime(nflop, output.rddStats.numPartitions, executorMetrics, output.rddStats);
- }
-
- public static double getUnaryInstTime(String opcode, VarStats input, VarStats output, IOMetrics executorMetrics) {
- // handles operations of type Builtin as Unary
- // Unary adds a map() to an open stage
- long nflop = getInstNFLOP(SPType.Unary, opcode, output, input);
- double mapTime = getCPUTime(nflop, input.rddStats.numPartitions, executorMetrics, output.rddStats, input.rddStats);
- // the resulting rdd is being hash-partitioned depending on the input one
- output.rddStats.hashPartitioned = input.rddStats.hashPartitioned;
- return mapTime;
- }
-
- public static double getAggUnaryInstTime(UnarySPInstruction inst, VarStats input, VarStats output, IOMetrics executorMetrics) {
- // AggregateUnary results in different Spark execution plan depending on the output dimensions
- String opcode = inst.getOpcode();
- AggBinaryOp.SparkAggType aggType = (inst instanceof AggregateUnarySPInstruction)?
- ((AggregateUnarySPInstruction) inst).getAggType(): - ((AggregateUnarySketchSPInstruction) inst).getAggType(); - double shuffleTime; - if (inst instanceof CumulativeAggregateSPInstruction) { - shuffleTime = getSparkShuffleTime(output.rddStats, executorMetrics, true); - output.rddStats.hashPartitioned = true; - } else { - if (aggType == AggBinaryOp.SparkAggType.SINGLE_BLOCK) { - // loading RDD to the driver (CP) explicitly (not triggered by CP instruction) - output.rddStats.isCollected = true; - // cost for transferring result values (result from fold()) is negligible -> cost = computation time - shuffleTime = 0; - } else if (aggType == AggBinaryOp.SparkAggType.MULTI_BLOCK) { - // combineByKey() triggers a new stage -> cost = computation time + shuffle time (combineByKey); - if (opcode.equals("uaktrace")) { - long diagonalBlockSize = OptimizerUtils.estimatePartitionedSizeExactSparsity( - input.characteristics.getBlocksize() * input.getM(), - input.characteristics.getBlocksize(), - input.characteristics.getBlocksize(), - input.getNNZ() - ); - RDDStats filteredRDD = new RDDStats(diagonalBlockSize, input.rddStats.numPartitions); - shuffleTime = getSparkShuffleTime(filteredRDD, executorMetrics, true); - } else { - shuffleTime = getSparkShuffleTime(input.rddStats, executorMetrics, true); - } - output.rddStats.hashPartitioned = true; - output.rddStats.numPartitions = input.rddStats.numPartitions; - } else { // aggType == AggBinaryOp.SparkAggType.NONE - output.rddStats.hashPartitioned = input.rddStats.hashPartitioned; - output.rddStats.numPartitions = input.rddStats.numPartitions; - // only mapping transformation -> cost = computation time - shuffleTime = 0; - } - } - long nflop = getInstNFLOP(SPType.AggregateUnary, opcode, output, input); - double mapTime = getCPUTime(nflop, input.rddStats.numPartitions, executorMetrics, output.rddStats, input.rddStats); - return shuffleTime + mapTime; - } - - public static double getIndexingInstTime(IndexingSPInstruction inst, VarStats input1, VarStats input2, VarStats output, IOMetrics driverMetrics, IOMetrics executorMetrics) { - String opcode = inst.getOpcode(); - double dataTransmissionTime; - if (opcode.equals(RightIndex.OPCODE)) { - // assume direct collecting if output dimensions not larger than block size - int blockSize = ConfigurationManager.getBlocksize(); - if (output.getM() <= blockSize && output.getN() <= blockSize) { - // represents single block and multi block cases - dataTransmissionTime = getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics); - output.rddStats.isCollected = true; - } else { - // represents general indexing: worst case: shuffling required - dataTransmissionTime = getSparkShuffleTime(output.rddStats, executorMetrics, true); - } - } else if (opcode.equals(LeftIndex.OPCODE)) { - // model combineByKey() with shuffling the second input - dataTransmissionTime = getSparkShuffleTime(input2.rddStats, executorMetrics, true); - } else { // mapLeftIndex - dataTransmissionTime = getSparkBroadcastTime(input2, driverMetrics, executorMetrics); - } - long nflop = getInstNFLOP(SPType.MatrixIndexing, opcode, output); - // scan only the size of the output since filter is applied first - RDDStats[] objectsToScan = (input2 == null)? 
new RDDStats[]{output.rddStats} : - new RDDStats[]{output.rddStats, output.rddStats}; - double mapTime = getCPUTime(nflop, input1.rddStats.numPartitions, executorMetrics, output.rddStats, objectsToScan); - return dataTransmissionTime + mapTime; - } - - public static double getBinaryInstTime(SPInstruction inst, VarStats input1, VarStats input2, VarStats output, IOMetrics driverMetrics, IOMetrics executorMetrics) { - SPType opType = inst.getSPInstructionType(); - String opcode = inst.getOpcode(); - // binary, builtin binary (log and log_nz) - // for the NFLOP calculation if the function is executed as map is not relevant - if (opcode.startsWith("map")) { - opcode = opcode.substring(3); - } - double dataTransmissionTime; - if (inst instanceof BinaryMatrixMatrixSPInstruction) { - if (inst instanceof BinaryMatrixBVectorSPInstruction) { - // the second matrix is always the broadcast one - dataTransmissionTime = getSparkBroadcastTime(input2, driverMetrics, executorMetrics); - // flatMapToPair() or ()mapPartitionsToPair invoked -> no shuffling - output.rddStats.numPartitions = input1.rddStats.numPartitions; - output.rddStats.hashPartitioned = input1.rddStats.hashPartitioned; - } else { // regular BinaryMatrixMatrixSPInstruction - // join() input1 and input2 - dataTransmissionTime = getSparkShuffleWriteTime(input1.rddStats, executorMetrics) + - getSparkShuffleWriteTime(input2.rddStats, executorMetrics); - if (input1.rddStats.hashPartitioned) { - output.rddStats.numPartitions = input1.rddStats.numPartitions; - if (!input2.rddStats.hashPartitioned || !(input1.rddStats.numPartitions == input2.rddStats.numPartitions)) { - // shuffle needed for join() -> actual shuffle only for input2 - dataTransmissionTime += getSparkShuffleReadStaticTime(input1.rddStats, executorMetrics) + - getSparkShuffleReadTime(input2.rddStats, executorMetrics); - } else { // no shuffle needed for join() -> only read from local disk - dataTransmissionTime += getSparkShuffleReadStaticTime(input1.rddStats, executorMetrics) + - getSparkShuffleReadStaticTime(input2.rddStats, executorMetrics); - } - } else if (input2.rddStats.hashPartitioned) { - output.rddStats.numPartitions = input2.rddStats.numPartitions; - // input1 not hash partitioned: shuffle needed for join() -> actual shuffle only for input2 - dataTransmissionTime += getSparkShuffleReadStaticTime(input1.rddStats, executorMetrics) + - getSparkShuffleReadTime(input2.rddStats, executorMetrics); - } else { - // repartition all data needed - output.rddStats.numPartitions = 2 * output.rddStats.numPartitions; - dataTransmissionTime += getSparkShuffleReadTime(input1.rddStats, executorMetrics) + - getSparkShuffleReadTime(input2.rddStats, executorMetrics); - } - output.rddStats.hashPartitioned = true; - } - } else if (inst instanceof BinaryMatrixScalarSPInstruction) { - // only mapValues() invoked -> no shuffling - dataTransmissionTime = 0; - output.rddStats.hashPartitioned = (input2.isScalar())? 
input1.rddStats.hashPartitioned : input2.rddStats.hashPartitioned; - } else if (inst instanceof BinaryFrameMatrixSPInstruction || inst instanceof BinaryFrameFrameSPInstruction) { - throw new RuntimeException("Handling binary instructions for frames not handled yet."); - } else { - throw new RuntimeException("Not supported binary instruction: "+inst); - } - long nflop = getInstNFLOP(opType, opcode, output, input1, input2); - double mapTime = getCPUTime(nflop, output.rddStats.numPartitions, executorMetrics, output.rddStats, input1.rddStats, input2.rddStats); - return dataTransmissionTime + mapTime; - } - - public static double getAppendInstTime(AppendSPInstruction inst, VarStats input1, VarStats input2, VarStats output, IOMetrics driverMetrics, IOMetrics executorMetrics) { - double dataTransmissionTime; - if (inst instanceof AppendMSPInstruction) { - dataTransmissionTime = getSparkBroadcastTime(input2, driverMetrics, executorMetrics); - output.rddStats.hashPartitioned = true; - } else if (inst instanceof AppendRSPInstruction) { - dataTransmissionTime = getSparkShuffleTime(output.rddStats, executorMetrics, false); - } else if (inst instanceof AppendGAlignedSPInstruction) { - // only changing matrix indexing - dataTransmissionTime = 0; - } else { // AppendGSPInstruction - // shuffle the whole appended matrix - dataTransmissionTime = getSparkShuffleTime(input2.rddStats, executorMetrics, true); - output.rddStats.hashPartitioned = true; - } - // opcode not relevant for the nflop estimation of append instructions; - long nflop = getInstNFLOP(inst.getSPInstructionType(), "append", output, input1, input2); - double mapTime = getCPUTime(nflop, output.rddStats.numPartitions, executorMetrics, output.rddStats, input1.rddStats, input2.rddStats); - return dataTransmissionTime + mapTime; - } - - public static double getReorgInstTime(UnarySPInstruction inst, VarStats input, VarStats output, IOMetrics executorMetrics) { - // includes logic for MatrixReshapeSPInstruction - String opcode = inst.getOpcode(); - double dataTransmissionTime; - switch (opcode) { - case "rshape": - dataTransmissionTime = getSparkShuffleTime(input.rddStats, executorMetrics, true); - output.rddStats.hashPartitioned = true; - break; - case "r'": - dataTransmissionTime = 0; - output.rddStats.hashPartitioned = input.rddStats.hashPartitioned; - break; - case "rev": - dataTransmissionTime = getSparkShuffleTime(output.rddStats, executorMetrics, true); - output.rddStats.hashPartitioned = true; - break; - case "rdiag": - dataTransmissionTime = 0; - output.rddStats.numPartitions = input.rddStats.numPartitions; - output.rddStats.hashPartitioned = input.rddStats.hashPartitioned; - break; - default: // rsort - String ixretAsString = InstructionUtils.getInstructionParts(inst.getInstructionString())[4]; - boolean ixret = ixretAsString.equalsIgnoreCase("true"); - int shuffleFactor; - if (ixret) { // index return - shuffleFactor = 2; // estimate cost for 2 shuffles - } else { - shuffleFactor = 4;// estimate cost for 2 shuffles - } - // assume case: 4 times shuffling the output - dataTransmissionTime = getSparkShuffleWriteTime(output.rddStats, executorMetrics) + - getSparkShuffleReadTime(output.rddStats, executorMetrics); - dataTransmissionTime *= shuffleFactor; - break; - } - long nflop = getInstNFLOP(inst.getSPInstructionType(), opcode, output); // uses output only - double mapTime = getCPUTime(nflop, output.rddStats.numPartitions, executorMetrics, output.rddStats, input.rddStats); - return dataTransmissionTime + mapTime; - } - - public static 
double getTSMMInstTime(UnarySPInstruction inst, VarStats input, VarStats output, IOMetrics driverMetrics, IOMetrics executorMetrics) { - String opcode = inst.getOpcode(); - MMTSJ.MMTSJType type; - - double dataTransmissionTime; - if (inst instanceof TsmmSPInstruction) { - type = ((TsmmSPInstruction) inst).getMMTSJType(); - // fold() used but result is still a whole matrix block - dataTransmissionTime = getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics); - output.rddStats.isCollected = true; - } else { // Tsmm2SPInstruction - type = ((Tsmm2SPInstruction) inst).getMMTSJType(); - // assumes always default output with collect - long rowsRange = (type == MMTSJ.MMTSJType.LEFT)? input.getM() : - input.getM() - input.characteristics.getBlocksize(); - long colsRange = (type != MMTSJ.MMTSJType.LEFT)? input.getN() : - input.getN() - input.characteristics.getBlocksize(); - VarStats broadcast = new VarStats("tmp1", new MatrixCharacteristics(rowsRange, colsRange)); - broadcast.rddStats = new RDDStats(broadcast); - dataTransmissionTime = getSparkCollectTime(broadcast.rddStats, driverMetrics, executorMetrics); - dataTransmissionTime += getSparkBroadcastTime(broadcast, driverMetrics, executorMetrics); - dataTransmissionTime += getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics); - } - opcode += type.isLeft() ? "_left" : "_right"; - long nflop = getInstNFLOP(inst.getSPInstructionType(), opcode, output, input); - double mapTime = getCPUTime(nflop, input.rddStats.numPartitions, executorMetrics, output.rddStats, input.rddStats); - return dataTransmissionTime + mapTime; - } - - public static double getCentralMomentInstTime(CentralMomentSPInstruction inst, VarStats input, VarStats weights, VarStats output, IOMetrics executorMetrics) { - CMOperator.AggregateOperationTypes opType = ((CMOperator) inst.getOperator()).getAggOpType(); - String opcode = inst.getOpcode() + "_" + opType.name().toLowerCase(); - - double dataTransmissionTime = 0; - if (weights != null) { - dataTransmissionTime = getSparkShuffleWriteTime(weights.rddStats, executorMetrics) + - getSparkShuffleReadTime(weights.rddStats, executorMetrics); - - } - output.rddStats.isCollected = true; - - RDDStats[] RDDInputs = (weights == null)? 
new RDDStats[]{input.rddStats} : new RDDStats[]{input.rddStats, weights.rddStats}; - long nflop = getInstNFLOP(inst.getSPInstructionType(), opcode, output, input); - double mapTime = getCPUTime(nflop, input.rddStats.numPartitions, executorMetrics, output.rddStats, RDDInputs); - return dataTransmissionTime + mapTime; - } - - public static double getCastInstTime(CastSPInstruction inst, VarStats input, VarStats output, IOMetrics executorMetrics) { - double shuffleTime = 0; - if (input.getN() > input.characteristics.getBlocksize()) { - shuffleTime = getSparkShuffleWriteTime(input.rddStats, executorMetrics) + - getSparkShuffleReadTime(input.rddStats, executorMetrics); - output.rddStats.hashPartitioned = true; - } - long nflop = getInstNFLOP(inst.getSPInstructionType(), inst.getOpcode(), output, input); - double mapTime = getCPUTime(nflop, input.rddStats.numPartitions, executorMetrics, output.rddStats, input.rddStats); - return shuffleTime + mapTime; - } - - public static double getQSortInstTime(QuantileSortSPInstruction inst, VarStats input, VarStats weights, VarStats output, IOMetrics executorMetrics) { - String opcode = inst.getOpcode(); - double shuffleTime = 0; - if (weights != null) { - opcode += "_wts"; - shuffleTime += getSparkShuffleWriteTime(weights.rddStats, executorMetrics) + - getSparkShuffleReadTime(weights.rddStats, executorMetrics); - } - shuffleTime += getSparkShuffleWriteTime(output.rddStats, executorMetrics) + - getSparkShuffleReadTime(output.rddStats, executorMetrics); - output.rddStats.hashPartitioned = true; - - long nflop = getInstNFLOP(SPType.QSort, opcode, output, input, weights); - RDDStats[] RDDInputs = (weights == null)? new RDDStats[]{input.rddStats} : new RDDStats[]{input.rddStats, weights.rddStats}; - double mapTime = getCPUTime(nflop, input.rddStats.numPartitions, executorMetrics, output.rddStats, RDDInputs); - return shuffleTime + mapTime; - } - - public static double getMatMulInstTime(BinarySPInstruction inst, VarStats input1, VarStats input2, VarStats output, IOMetrics driverMetrics, IOMetrics executorMetrics) { - double dataTransmissionTime; - int numPartitionsForMapping; - if (inst instanceof CpmmSPInstruction) { - CpmmSPInstruction cpmminst = (CpmmSPInstruction) inst; - AggBinaryOp.SparkAggType aggType = cpmminst.getAggType(); - // estimate for in1.join(in2) - long joinedSize = input1.rddStats.distributedSize + input2.rddStats.distributedSize; - RDDStats joinedRDD = new RDDStats(joinedSize, -1); - dataTransmissionTime = getSparkShuffleTime(joinedRDD, executorMetrics, true); - if (aggType == AggBinaryOp.SparkAggType.SINGLE_BLOCK) { - dataTransmissionTime += getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics); - output.rddStats.isCollected = true; - } else { - dataTransmissionTime += getSparkShuffleTime(output.rddStats, executorMetrics, true); - output.rddStats.hashPartitioned = true; - } - numPartitionsForMapping = joinedRDD.numPartitions; - } else if (inst instanceof RmmSPInstruction) { - // estimate for in1.join(in2) - long joinedSize = input1.rddStats.distributedSize + input2.rddStats.distributedSize; - RDDStats joinedRDD = new RDDStats(joinedSize, -1); - dataTransmissionTime = getSparkShuffleTime(joinedRDD, executorMetrics, true); - // estimate for out.combineByKey() per partition - dataTransmissionTime += getSparkShuffleTime(output.rddStats, executorMetrics, false); - output.rddStats.hashPartitioned = true; - numPartitionsForMapping = joinedRDD.numPartitions; - } else if (inst instanceof MapmmSPInstruction) { - dataTransmissionTime = 
getSparkBroadcastTime(input2, driverMetrics, executorMetrics); - MapmmSPInstruction mapmminst = (MapmmSPInstruction) inst; - AggBinaryOp.SparkAggType aggType = mapmminst.getAggType(); - if (aggType == AggBinaryOp.SparkAggType.SINGLE_BLOCK) { - dataTransmissionTime += getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics); - output.rddStats.isCollected = true; - } else { - dataTransmissionTime += getSparkShuffleTime(output.rddStats, executorMetrics, true); - output.rddStats.hashPartitioned = true; - } - numPartitionsForMapping = input1.rddStats.numPartitions; - } else if (inst instanceof PmmSPInstruction) { - dataTransmissionTime = getSparkBroadcastTime(input2, driverMetrics, executorMetrics); - output.rddStats.numPartitions = input1.rddStats.numPartitions; - dataTransmissionTime += getSparkShuffleTime(output.rddStats, executorMetrics, true); - output.rddStats.hashPartitioned = true; - numPartitionsForMapping = input1.rddStats.numPartitions; - } else if (inst instanceof ZipmmSPInstruction) { - // assume always a shuffle without data re-distribution - dataTransmissionTime = getSparkShuffleTime(output.rddStats, executorMetrics, false); - dataTransmissionTime += getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics); - numPartitionsForMapping = input1.rddStats.numPartitions; - output.rddStats.isCollected = true; - } else if (inst instanceof PMapmmSPInstruction) { - throw new RuntimeException("PMapmmSPInstruction instruction is still experimental and not supported yet"); - } else { - throw new RuntimeException(inst.getClass().getName() + " instruction is not handled by the current method"); - } - long nflop = getInstNFLOP(inst.getSPInstructionType(), inst.getOpcode(), output, input1, input2); - double mapTime; - if (inst instanceof MapmmSPInstruction || inst instanceof PmmSPInstruction) { - // scan only first input - mapTime = getCPUTime(nflop, numPartitionsForMapping, executorMetrics, output.rddStats, input1.rddStats); - } else { - mapTime = getCPUTime(nflop, numPartitionsForMapping, executorMetrics, output.rddStats, input1.rddStats, input2.rddStats); - } - return dataTransmissionTime + mapTime; - } - - public static double getMatMulChainInstTime(MapmmChainSPInstruction inst, VarStats input1, VarStats input2, VarStats input3, VarStats output, - IOMetrics driverMetrics, IOMetrics executorMetrics) { - double dataTransmissionTime = 0; - if (input3 != null) { - dataTransmissionTime += getSparkBroadcastTime(input3, driverMetrics, executorMetrics); - } - dataTransmissionTime += getSparkBroadcastTime(input2, driverMetrics, executorMetrics); - dataTransmissionTime += getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics); - output.rddStats.isCollected = true; - - long nflop = getInstNFLOP(SPType.MAPMMCHAIN, inst.getOpcode(), output, input1, input2); - double mapTime = getCPUTime(nflop, input1.rddStats.numPartitions, executorMetrics, output.rddStats, input1.rddStats); - return dataTransmissionTime + mapTime; - } - - public static double getCtableInstTime(CtableSPInstruction tableInst, VarStats input1, VarStats input2, VarStats input3, VarStats output, IOMetrics executorMetrics) { - String opcode = tableInst.getOpcode(); - double shuffleTime; - if (opcode.equals("ctableexpand") || !input2.isScalar() && input3.isScalar()) { // CTABLE_EXPAND_SCALAR_WEIGHT/CTABLE_TRANSFORM_SCALAR_WEIGHT - // in1.join(in2) - shuffleTime = getSparkShuffleTime(input2.rddStats, executorMetrics, true); - } else if (input2.isScalar() && input3.isScalar()) { // CTABLE_TRANSFORM_HISTOGRAM 
- // no joins - shuffleTime = 0; - } else if (input2.isScalar() && !input3.isScalar()) { // CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM - // in1.join(in3) - shuffleTime = getSparkShuffleTime(input3.rddStats, executorMetrics, true); - } else { // CTABLE_TRANSFORM - // in1.join(in2).join(in3) - shuffleTime = getSparkShuffleTime(input2.rddStats, executorMetrics, true); - shuffleTime += getSparkShuffleTime(input3.rddStats, executorMetrics, true); - } - // combineByKey() - shuffleTime += getSparkShuffleTime(output.rddStats, executorMetrics, true); - output.rddStats.hashPartitioned = true; - - long nflop = getInstNFLOP(SPType.Ctable, opcode, output, input1, input2, input3); - double mapTime = getCPUTime(nflop, output.rddStats.numPartitions, executorMetrics, - output.rddStats, input1.rddStats, input2.rddStats, input3.rddStats); - - return shuffleTime + mapTime; - } - - public static double getParameterizedBuiltinInstTime(ParameterizedBuiltinSPInstruction paramInst, VarStats input1, VarStats input2, VarStats output, IOMetrics driverMetrics, IOMetrics executorMetrics) { - String opcode = paramInst.getOpcode(); - double dataTransmissionTime; - switch (opcode) { - case "rmempty": - if (input2.rddStats == null) // broadcast - dataTransmissionTime = getSparkBroadcastTime(input2, driverMetrics, executorMetrics); - else // join - dataTransmissionTime = getSparkShuffleTime(input1.rddStats, executorMetrics, true); - dataTransmissionTime += getSparkShuffleTime(output.rddStats, executorMetrics, true); - break; - case "contains": - if (input2.isScalar()) { - dataTransmissionTime = 0; - } else { - dataTransmissionTime = getSparkBroadcastTime(input2, driverMetrics, executorMetrics); - // ignore reduceByKey() cost - } - output.rddStats.isCollected = true; - break; - case "replace": - case "lowertri": - case "uppertri": - dataTransmissionTime = 0; - break; - default: - throw new RuntimeException("Spark operation ParameterizedBuiltin with opcode " + opcode + " is not supported yet"); - } - - long nflop = getInstNFLOP(paramInst.getSPInstructionType(), opcode, output, input1); - double mapTime = getCPUTime(nflop, input1.rddStats.numPartitions, executorMetrics, output.rddStats, input1.rddStats); - - return dataTransmissionTime + mapTime; - } - - /** - * Computes an estimate for the time needed by the CPU to execute (including memory access) - * an instruction by providing number of floating operations. - * - * @param nflop number FLOP to execute a target CPU operation - * @param numPartitions number partitions used to execute the target operation; - * not bound to any of the input/output statistics object to allow more - * flexibility depending on the corresponding instruction - * @param executorMetrics metrics for the executor utilized by the Spark cluster - * @param output statistics for the output variable - * @param inputs arrays of statistics for the output variable - * @return time estimate - */ - public static double getCPUTime(long nflop, int numPartitions, IOMetrics executorMetrics, RDDStats output, RDDStats...inputs) { - double memScanTime = 0; - for (RDDStats input: inputs) { - if (input == null) continue; - // compensates for spill-overs to account for non-compute bound operations - memScanTime += getMemReadTime(input, executorMetrics); - } - double numWaves = Math.ceil((double) numPartitions / SparkExecutionContext.getDefaultParallelism(false)); - double scaledNFLOP = (numWaves * nflop) / numPartitions; - double cpuComputationTime = scaledNFLOP / executorMetrics.cpuFLOPS; - double memWriteTime = output != null? 
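// Worked instance of the wave-based scaling in getCPUTime above (assumed,
// hypothetical numbers, for illustration only): with a default parallelism
// of 16 virtual cores and numPartitions = 40, numWaves = ceil(40 / 16) = 3,
// so scaledNFLOP = (3 * nflop) / 40, i.e. the per-partition share of the
// work is charged once per wave; the returned time is then
// max(memScanTime, scaledNFLOP / cpuFLOPS) + memWriteTime.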
getMemWriteTime(output, executorMetrics) : 0; - return Math.max(memScanTime, cpuComputationTime) + memWriteTime; - } - - public static void assignOutputRDDStats(SPInstruction inst, VarStats output, VarStats...inputs) { - if (!output.isScalar()) { - SPType instType = inst.getSPInstructionType(); - String opcode = inst.getOpcode(); - if (output.getCells() < 0) { - inferStats(instType, opcode, output, inputs); - } - } - output.rddStats = new RDDStats(output); - } - - private static void inferStats(SPType instType, String opcode, VarStats output, VarStats...inputs) { - switch (instType) { - case Unary: - case Builtin: - CPCostUtils.inferStats(CPType.Unary, opcode, output, inputs); - break; - case AggregateUnary: - case AggregateUnarySketch: - CPCostUtils.inferStats(CPType.AggregateUnary, opcode, output, inputs); - case MatrixIndexing: - CPCostUtils.inferStats(CPType.MatrixIndexing, opcode, output, inputs); - break; - case Reorg: - CPCostUtils.inferStats(CPType.Reorg, opcode, output, inputs); - break; - case Binary: - CPCostUtils.inferStats(CPType.Binary, opcode, output, inputs); - break; - case CPMM: - case RMM: - case MAPMM: - case PMM: - case ZIPMM: - CPCostUtils.inferStats(CPType.AggregateBinary, opcode, output, inputs); - break; - case ParameterizedBuiltin: - CPCostUtils.inferStats(CPType.ParameterizedBuiltin, opcode, output, inputs); - break; - case Rand: - CPCostUtils.inferStats(CPType.Rand, opcode, output, inputs); - break; - case Ctable: - CPCostUtils.inferStats(CPType.Ctable, opcode, output, inputs); - break; - default: - throw new RuntimeException("Operation of type "+instType+" with opcode '"+opcode+"' has no formula for inferring dimensions"); - } - if (output.getCells() < 0) { - throw new RuntimeException("Operation of type "+instType+" with opcode '"+opcode+"' has incomplete formula for inferring dimensions"); - } - } - - private static long getInstNFLOP( - SPType instructionType, - String opcode, - VarStats output, - VarStats...inputs - ) { - opcode = opcode.toLowerCase(); - double costs; - switch (instructionType) { - case Reblock: - if (opcode.startsWith("libsvm")) { - return output.getCellsWithSparsity(); - } else { // starts with "rblk" or "csvrblk" - return output.getCells(); - } - case Unary: - case Builtin: - return CPCostUtils.getInstNFLOP(CPType.Unary, opcode, output, inputs); - case AggregateUnary: - case AggregateUnarySketch: - switch (opcode) { - case "uacdr": - case "uacdc": - throw new DMLRuntimeException(opcode + " opcode is not implemented by SystemDS"); - default: - return CPCostUtils.getInstNFLOP(CPType.AggregateUnary, opcode, output, inputs); - } - case CumsumAggregate: - switch (opcode) { - case "ucumack+": - case "ucumac*": - case "ucumacmin": - case "ucumacmax": - costs = 1; break; - case "ucumac+*": - costs = 2; break; - default: - throw new DMLRuntimeException(opcode + " opcode is not implemented by SystemDS"); - } - return (long) (costs * inputs[0].getCells() + costs * output.getN()); - case TSMM: - case TSMM2: - return CPCostUtils.getInstNFLOP(CPType.MMTSJ, opcode, output, inputs); - case Reorg: - case MatrixReshape: - return CPCostUtils.getInstNFLOP(CPType.Reorg, opcode, output, inputs); - case MatrixIndexing: - // the actual opcode value is not used at the moment - return CPCostUtils.getInstNFLOP(CPType.MatrixIndexing, opcode, output, inputs); - case Cast: - return output.getCellsWithSparsity(); - case QSort: - return CPCostUtils.getInstNFLOP(CPType.QSort, opcode, output, inputs); - case CentralMoment: - return 
CPCostUtils.getInstNFLOP(CPType.CentralMoment, opcode, output, inputs); - case UaggOuterChain: - case Dnn: - throw new RuntimeException("Spark operation type'" + instructionType + "' is not supported yet"); - // types corresponding to BinaryCPInstruction - case Binary: - switch (opcode) { - case "+*": - case "-*": - // original "map+*" and "map-*" - // "+*" and "-*" defined as ternary - throw new RuntimeException("Spark operation with opcode '" + opcode + "' is not supported yet"); - default: - return CPCostUtils.getInstNFLOP(CPType.Binary, opcode, output, inputs); - } - case CPMM: - case RMM: - case MAPMM: - case PMM: - case ZIPMM: - case PMAPMM: - // do not reduce by factor of 2: not explicit matrix multiplication - return 2 * CPCostUtils.getInstNFLOP(CPType.AggregateBinary, opcode, output, inputs); - case MAPMMCHAIN: - return 2 * inputs[0].getCells() * inputs[0].getN() // ba(+*) - + 2 * inputs[0].getM() * inputs[1].getN() // cellwise b(*) + r(t) - + 2 * inputs[0].getCellsWithSparsity() * inputs[1].getN() // ba(+*) - + inputs[1].getM() * output.getM() ; //r(t) - case BinUaggChain: - break; - case MAppend: - case RAppend: - case GAppend: - case GAlignedAppend: - // the actual opcode value is not used at the moment - return CPCostUtils.getInstNFLOP(CPType.Append, opcode, output, inputs); - case BuiltinNary: - return CPCostUtils.getInstNFLOP(CPType.BuiltinNary, opcode, output, inputs); - case Ctable: - return CPCostUtils.getInstNFLOP(CPType.Ctable, opcode, output, inputs); - case ParameterizedBuiltin: - return CPCostUtils.getInstNFLOP(CPType.ParameterizedBuiltin, opcode, output, inputs); - default: - // all existing cases should have been handled above - throw new DMLRuntimeException("Spark operation type'" + instructionType + "' is not supported by SystemDS"); - } - throw new RuntimeException(); - } - - -// //ternary aggregate operators -// case "tak+*": -// break; -// case "tack+*": -// break; -// // Neural network operators -// case "conv2d": -// case "conv2d_bias_add": -// case "maxpooling": -// case "relu_maxpooling": -// case RightIndex.OPCODE: -// case LeftIndex.OPCODE: -// case "mapLeftIndex": -// case "_map",: -// break; -// // Spark-specific instructions -// case Checkpoint.DEFAULT_CP_OPCODE,: -// break; -// case Checkpoint.ASYNC_CP_OPCODE,: -// break; -// case Compression.OPCODE,: -// break; -// case DeCompression.OPCODE,: -// break; -// // Parameterized Builtin Functions -// case "autoDiff",: -// break; -// case "contains",: -// break; -// case "groupedagg",: -// break; -// case "mapgroupedagg",: -// break; -// case "rmempty",: -// break; -// case "replace",: -// break; -// case "rexpand",: -// break; -// case "lowertri",: -// break; -// case "uppertri",: -// break; -// case "tokenize",: -// break; -// case "transformapply",: -// break; -// case "transformdecode",: -// break; -// case "transformencode",: -// break; -// case "mappend",: -// break; -// case "rappend",: -// break; -// case "gappend",: -// break; -// case "galignedappend",: -// break; -// //ternary instruction opcodes -// case "ctable",: -// break; -// case "ctableexpand",: -// break; + public static double getReblockInstTime(String opcode, VarStats input, VarStats output, IOMetrics executorMetrics) { + // Reblock triggers a new stage + // old stage: read text file + shuffle the intermediate text rdd + double readTime = getHadoopReadTime(input, executorMetrics); + long sizeTextFile = OptimizerUtils.estimateSizeTextOutput(input.getM(), input.getN(), input.getNNZ(), (Types.FileFormat) input.fileInfo[1]); + RDDStats 
textRdd = new RDDStats(sizeTextFile, -1); + double shuffleTime = getSparkShuffleTime(textRdd, executorMetrics, true); + double timeStage1 = readTime + shuffleTime; + // new stage: transform partitioned shuffled text object into partitioned binary object + long nflop = getInstNFLOP(SPType.Reblock, opcode, output); + double timeStage2 = getCPUTime(nflop, textRdd.numPartitions, executorMetrics, output.rddStats, textRdd); + return timeStage1 + timeStage2; + } + + public static double getRandInstTime(String opcode, int randType, VarStats output, IOMetrics executorMetrics) { + if (opcode.equals(SAMPLE_OPCODE)) { + // sample uses sortByKey() op. and it should be handled differently + throw new RuntimeException("Spark operation Rand with opcode " + SAMPLE_OPCODE + " is not supported yet"); + } + + long nflop; + if (opcode.equals(RAND_OPCODE) || opcode.equals(FRAME_OPCODE)) { + if (randType == 0) return 0; // empty matrix + else if (randType == 1) nflop = 8; // allocate, array fill + else if (randType == 2) nflop = 32; // full rand + else throw new RuntimeException("Unknown type of random instruction"); + } else if (opcode.equals(SEQ_OPCODE)) { + nflop = 1; + } else { + throw new DMLRuntimeException("Rand operation with opcode '" + opcode + "' is not supported by SystemDS"); + } + nflop *= output.getCells(); + // no shuffling required -> only computation time + return getCPUTime(nflop, output.rddStats.numPartitions, executorMetrics, output.rddStats); + } + + public static double getUnaryInstTime(String opcode, VarStats input, VarStats output, IOMetrics executorMetrics) { + // handles operations of type Builtin as Unary + // Unary adds a map() to an open stage + long nflop = getInstNFLOP(SPType.Unary, opcode, output, input); + double mapTime = getCPUTime(nflop, input.rddStats.numPartitions, executorMetrics, output.rddStats, input.rddStats); + // the resulting rdd is being hash-partitioned depending on the input one + output.rddStats.hashPartitioned = input.rddStats.hashPartitioned; + return mapTime; + } + + public static double getAggUnaryInstTime(UnarySPInstruction inst, VarStats input, VarStats output, IOMetrics executorMetrics) { + // AggregateUnary results in different Spark execution plan depending on the output dimensions + String opcode = inst.getOpcode(); + AggBinaryOp.SparkAggType aggType = (inst instanceof AggregateUnarySPInstruction)? 
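+ // Summary of the aggregation cases below (descriptive only): SINGLE_BLOCK
+ // folds into a single block that is collected at the driver (no shuffle
+ // cost charged), MULTI_BLOCK triggers a combineByKey() and thus a shuffle
+ // (restricted to the diagonal blocks for uaktrace), and NONE stays a pure
+ // map transformation, so only the computation time is charged.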
+ ((AggregateUnarySPInstruction) inst).getAggType(): + ((AggregateUnarySketchSPInstruction) inst).getAggType(); + double shuffleTime; + if (inst instanceof CumulativeAggregateSPInstruction) { + shuffleTime = getSparkShuffleTime(output.rddStats, executorMetrics, true); + output.rddStats.hashPartitioned = true; + } else { + if (aggType == AggBinaryOp.SparkAggType.SINGLE_BLOCK) { + // loading RDD to the driver (CP) explicitly (not triggered by CP instruction) + output.rddStats.isCollected = true; + // cost for transferring result values (result from fold()) is negligible -> cost = computation time + shuffleTime = 0; + } else if (aggType == AggBinaryOp.SparkAggType.MULTI_BLOCK) { + // combineByKey() triggers a new stage -> cost = computation time + shuffle time (combineByKey); + if (opcode.equals("uaktrace")) { + long diagonalBlockSize = OptimizerUtils.estimatePartitionedSizeExactSparsity( + input.characteristics.getBlocksize() * input.getM(), + input.characteristics.getBlocksize(), + input.characteristics.getBlocksize(), + input.getNNZ() + ); + RDDStats filteredRDD = new RDDStats(diagonalBlockSize, input.rddStats.numPartitions); + shuffleTime = getSparkShuffleTime(filteredRDD, executorMetrics, true); + } else { + shuffleTime = getSparkShuffleTime(input.rddStats, executorMetrics, true); + } + output.rddStats.hashPartitioned = true; + output.rddStats.numPartitions = input.rddStats.numPartitions; + } else { // aggType == AggBinaryOp.SparkAggType.NONE + output.rddStats.hashPartitioned = input.rddStats.hashPartitioned; + output.rddStats.numPartitions = input.rddStats.numPartitions; + // only mapping transformation -> cost = computation time + shuffleTime = 0; + } + } + long nflop = getInstNFLOP(SPType.AggregateUnary, opcode, output, input); + double mapTime = getCPUTime(nflop, input.rddStats.numPartitions, executorMetrics, output.rddStats, input.rddStats); + return shuffleTime + mapTime; + } + + public static double getIndexingInstTime(IndexingSPInstruction inst, VarStats input1, VarStats input2, VarStats output, IOMetrics driverMetrics, IOMetrics executorMetrics) { + String opcode = inst.getOpcode(); + double dataTransmissionTime; + if (opcode.equals(RightIndex.OPCODE)) { + // assume direct collecting if output dimensions not larger than block size + int blockSize = ConfigurationManager.getBlocksize(); + if (output.getM() <= blockSize && output.getN() <= blockSize) { + // represents single block and multi block cases + dataTransmissionTime = getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics); + output.rddStats.isCollected = true; + } else { + // represents general indexing: worst case: shuffling required + dataTransmissionTime = getSparkShuffleTime(output.rddStats, executorMetrics, true); + } + } else if (opcode.equals(LeftIndex.OPCODE)) { + // model combineByKey() with shuffling the second input + dataTransmissionTime = getSparkShuffleTime(input2.rddStats, executorMetrics, true); + } else { // mapLeftIndex + dataTransmissionTime = getSparkBroadcastTime(input2, driverMetrics, executorMetrics); + } + long nflop = getInstNFLOP(SPType.MatrixIndexing, opcode, output); + // scan only the size of the output since filter is applied first + RDDStats[] objectsToScan = (input2 == null)? 
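+ // Note on the indexing cases above (illustrative, assuming the default
+ // block size of 1000): a right-indexing result of at most 1000 x 1000
+ // cells is modeled as a direct collect at the driver, e.g. X[1:500, 1:20]
+ // (a 500 x 20 output), while a larger slice such as X[1:5000, ] falls
+ // back to the general shuffle path.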
new RDDStats[]{output.rddStats} :
+ new RDDStats[]{output.rddStats, output.rddStats};
+ double mapTime = getCPUTime(nflop, input1.rddStats.numPartitions, executorMetrics, output.rddStats, objectsToScan);
+ return dataTransmissionTime + mapTime;
+ }
+
+ public static double getBinaryInstTime(SPInstruction inst, VarStats input1, VarStats input2, VarStats output, IOMetrics driverMetrics, IOMetrics executorMetrics) {
+ SPType opType = inst.getSPInstructionType();
+ String opcode = inst.getOpcode();
+ // binary, builtin binary (log and log_nz)
+ // whether the function is executed as a map-side variant is not relevant for the NFLOP calculation
+ if (opcode.startsWith("map")) {
+ opcode = opcode.substring(3);
+ }
+ double dataTransmissionTime;
+ if (inst instanceof BinaryMatrixMatrixSPInstruction) {
+ if (inst instanceof BinaryMatrixBVectorSPInstruction) {
+ // the second matrix is always the broadcast one
+ dataTransmissionTime = getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
+ // flatMapToPair() or mapPartitionsToPair() invoked -> no shuffling
+ output.rddStats.numPartitions = input1.rddStats.numPartitions;
+ output.rddStats.hashPartitioned = input1.rddStats.hashPartitioned;
+ } else { // regular BinaryMatrixMatrixSPInstruction
+ // join() input1 and input2
+ dataTransmissionTime = getSparkShuffleWriteTime(input1.rddStats, executorMetrics) +
+ getSparkShuffleWriteTime(input2.rddStats, executorMetrics);
+ if (input1.rddStats.hashPartitioned) {
+ output.rddStats.numPartitions = input1.rddStats.numPartitions;
+ if (!input2.rddStats.hashPartitioned || !(input1.rddStats.numPartitions == input2.rddStats.numPartitions)) {
+ // shuffle needed for join() -> actual shuffle only for input2
+ dataTransmissionTime += getSparkShuffleReadStaticTime(input1.rddStats, executorMetrics) +
+ getSparkShuffleReadTime(input2.rddStats, executorMetrics);
+ } else { // no shuffle needed for join() -> only read from local disk
+ dataTransmissionTime += getSparkShuffleReadStaticTime(input1.rddStats, executorMetrics) +
+ getSparkShuffleReadStaticTime(input2.rddStats, executorMetrics);
+ }
+ } else if (input2.rddStats.hashPartitioned) {
+ output.rddStats.numPartitions = input2.rddStats.numPartitions;
+ // input1 not hash partitioned: shuffle needed for join() -> actual shuffle only for input2
+ dataTransmissionTime += getSparkShuffleReadStaticTime(input1.rddStats, executorMetrics) +
+ getSparkShuffleReadTime(input2.rddStats, executorMetrics);
+ } else {
+ // repartitioning of all data needed
+ output.rddStats.numPartitions = 2 * output.rddStats.numPartitions;
+ dataTransmissionTime += getSparkShuffleReadTime(input1.rddStats, executorMetrics) +
+ getSparkShuffleReadTime(input2.rddStats, executorMetrics);
+ }
+ output.rddStats.hashPartitioned = true;
+ }
+ } else if (inst instanceof BinaryMatrixScalarSPInstruction) {
+ // only mapValues() invoked -> no shuffling
+ dataTransmissionTime = 0;
+ output.rddStats.hashPartitioned = (input2.isScalar())? input1.rddStats.hashPartitioned : input2.rddStats.hashPartitioned;
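+ // The join() cost model above distinguishes three cases (a descriptive
+ // summary, not normative): both sides hash-partitioned with equal
+ // partition counts (static local reads only), exactly one side reusable
+ // as-is (one static read plus one full shuffle read), and no usable
+ // partitioning at all (full shuffle reads of both sides and a doubled
+ // output partition count).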
+ } else if (inst instanceof BinaryFrameMatrixSPInstruction || inst instanceof BinaryFrameFrameSPInstruction) {
+ throw new RuntimeException("Binary instructions on frames are not handled yet.");
+ } else {
+ throw new RuntimeException("Unsupported binary instruction: "+inst);
+ }
+ long nflop = getInstNFLOP(opType, opcode, output, input1, input2);
+ double mapTime = getCPUTime(nflop, output.rddStats.numPartitions, executorMetrics, output.rddStats, input1.rddStats, input2.rddStats);
+ return dataTransmissionTime + mapTime;
+ }
+
+ public static double getAppendInstTime(AppendSPInstruction inst, VarStats input1, VarStats input2, VarStats output, IOMetrics driverMetrics, IOMetrics executorMetrics) {
+ double dataTransmissionTime;
+ if (inst instanceof AppendMSPInstruction) {
+ dataTransmissionTime = getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
+ output.rddStats.hashPartitioned = true;
+ } else if (inst instanceof AppendRSPInstruction) {
+ dataTransmissionTime = getSparkShuffleTime(output.rddStats, executorMetrics, false);
+ } else if (inst instanceof AppendGAlignedSPInstruction) {
+ // only changing matrix indexing
+ dataTransmissionTime = 0;
+ } else { // AppendGSPInstruction
+ // shuffle the whole appended matrix
+ dataTransmissionTime = getSparkShuffleTime(input2.rddStats, executorMetrics, true);
+ output.rddStats.hashPartitioned = true;
+ }
+ // opcode not relevant for the NFLOP estimation of append instructions
+ long nflop = getInstNFLOP(inst.getSPInstructionType(), "append", output, input1, input2);
+ double mapTime = getCPUTime(nflop, output.rddStats.numPartitions, executorMetrics, output.rddStats, input1.rddStats, input2.rddStats);
+ return dataTransmissionTime + mapTime;
+ }
+
+ public static double getReorgInstTime(UnarySPInstruction inst, VarStats input, VarStats output, IOMetrics executorMetrics) {
+ // includes logic for MatrixReshapeSPInstruction
+ String opcode = inst.getOpcode();
+ double dataTransmissionTime;
+ switch (opcode) {
+ case "rshape":
+ dataTransmissionTime = getSparkShuffleTime(input.rddStats, executorMetrics, true);
+ output.rddStats.hashPartitioned = true;
+ break;
+ case "r'":
+ dataTransmissionTime = 0;
+ output.rddStats.hashPartitioned = input.rddStats.hashPartitioned;
+ break;
+ case "rev":
+ dataTransmissionTime = getSparkShuffleTime(output.rddStats, executorMetrics, true);
+ output.rddStats.hashPartitioned = true;
+ break;
+ case "rdiag":
+ dataTransmissionTime = 0;
+ output.rddStats.numPartitions = input.rddStats.numPartitions;
+ output.rddStats.hashPartitioned = input.rddStats.hashPartitioned;
+ break;
+ default: // rsort
+ String ixretAsString = InstructionUtils.getInstructionParts(inst.getInstructionString())[4];
+ boolean ixret = ixretAsString.equalsIgnoreCase("true");
+ int shuffleFactor;
+ if (ixret) { // index return
+ shuffleFactor = 2; // estimate cost for 2 shuffles
+ } else {
+ shuffleFactor = 4; // estimate cost for 4 shuffles
+ }
+ // assume cost of shuffling the output shuffleFactor times
+ dataTransmissionTime = getSparkShuffleWriteTime(output.rddStats, executorMetrics) +
+ getSparkShuffleReadTime(output.rddStats, executorMetrics);
+ dataTransmissionTime *= shuffleFactor;
+ break;
+ }
+ long nflop = getInstNFLOP(inst.getSPInstructionType(), opcode, output); // uses output only
+ double mapTime = getCPUTime(nflop, output.rddStats.numPartitions, executorMetrics, output.rddStats, input.rddStats);
+ return dataTransmissionTime + mapTime;
+ }
+
+ public static
double getTSMMInstTime(UnarySPInstruction inst, VarStats input, VarStats output, IOMetrics driverMetrics, IOMetrics executorMetrics) { + String opcode = inst.getOpcode(); + MMTSJ.MMTSJType type; + + double dataTransmissionTime; + if (inst instanceof TsmmSPInstruction) { + type = ((TsmmSPInstruction) inst).getMMTSJType(); + // fold() used but result is still a whole matrix block + dataTransmissionTime = getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics); + output.rddStats.isCollected = true; + } else { // Tsmm2SPInstruction + type = ((Tsmm2SPInstruction) inst).getMMTSJType(); + // assumes always default output with collect + long rowsRange = (type == MMTSJ.MMTSJType.LEFT)? input.getM() : + input.getM() - input.characteristics.getBlocksize(); + long colsRange = (type != MMTSJ.MMTSJType.LEFT)? input.getN() : + input.getN() - input.characteristics.getBlocksize(); + VarStats broadcast = new VarStats("tmp1", new MatrixCharacteristics(rowsRange, colsRange)); + broadcast.rddStats = new RDDStats(broadcast); + dataTransmissionTime = getSparkCollectTime(broadcast.rddStats, driverMetrics, executorMetrics); + dataTransmissionTime += getSparkBroadcastTime(broadcast, driverMetrics, executorMetrics); + dataTransmissionTime += getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics); + } + opcode += type.isLeft() ? "_left" : "_right"; + long nflop = getInstNFLOP(inst.getSPInstructionType(), opcode, output, input); + double mapTime = getCPUTime(nflop, input.rddStats.numPartitions, executorMetrics, output.rddStats, input.rddStats); + return dataTransmissionTime + mapTime; + } + + public static double getCentralMomentInstTime(CentralMomentSPInstruction inst, VarStats input, VarStats weights, VarStats output, IOMetrics executorMetrics) { + CMOperator.AggregateOperationTypes opType = ((CMOperator) inst.getOperator()).getAggOpType(); + String opcode = inst.getOpcode() + "_" + opType.name().toLowerCase(); + + double dataTransmissionTime = 0; + if (weights != null) { + dataTransmissionTime = getSparkShuffleWriteTime(weights.rddStats, executorMetrics) + + getSparkShuffleReadTime(weights.rddStats, executorMetrics); + + } + output.rddStats.isCollected = true; + + RDDStats[] RDDInputs = (weights == null)? 
new RDDStats[]{input.rddStats} : new RDDStats[]{input.rddStats, weights.rddStats}; + long nflop = getInstNFLOP(inst.getSPInstructionType(), opcode, output, input); + double mapTime = getCPUTime(nflop, input.rddStats.numPartitions, executorMetrics, output.rddStats, RDDInputs); + return dataTransmissionTime + mapTime; + } + + public static double getCastInstTime(CastSPInstruction inst, VarStats input, VarStats output, IOMetrics executorMetrics) { + double shuffleTime = 0; + if (input.getN() > input.characteristics.getBlocksize()) { + shuffleTime = getSparkShuffleWriteTime(input.rddStats, executorMetrics) + + getSparkShuffleReadTime(input.rddStats, executorMetrics); + output.rddStats.hashPartitioned = true; + } + long nflop = getInstNFLOP(inst.getSPInstructionType(), inst.getOpcode(), output, input); + double mapTime = getCPUTime(nflop, input.rddStats.numPartitions, executorMetrics, output.rddStats, input.rddStats); + return shuffleTime + mapTime; + } + + public static double getQSortInstTime(QuantileSortSPInstruction inst, VarStats input, VarStats weights, VarStats output, IOMetrics executorMetrics) { + String opcode = inst.getOpcode(); + double shuffleTime = 0; + if (weights != null) { + opcode += "_wts"; + shuffleTime += getSparkShuffleWriteTime(weights.rddStats, executorMetrics) + + getSparkShuffleReadTime(weights.rddStats, executorMetrics); + } + shuffleTime += getSparkShuffleWriteTime(output.rddStats, executorMetrics) + + getSparkShuffleReadTime(output.rddStats, executorMetrics); + output.rddStats.hashPartitioned = true; + + long nflop = getInstNFLOP(SPType.QSort, opcode, output, input, weights); + RDDStats[] RDDInputs = (weights == null)? new RDDStats[]{input.rddStats} : new RDDStats[]{input.rddStats, weights.rddStats}; + double mapTime = getCPUTime(nflop, input.rddStats.numPartitions, executorMetrics, output.rddStats, RDDInputs); + return shuffleTime + mapTime; + } + + public static double getMatMulInstTime(BinarySPInstruction inst, VarStats input1, VarStats input2, VarStats output, IOMetrics driverMetrics, IOMetrics executorMetrics) { + double dataTransmissionTime; + int numPartitionsForMapping; + if (inst instanceof CpmmSPInstruction) { + CpmmSPInstruction cpmminst = (CpmmSPInstruction) inst; + AggBinaryOp.SparkAggType aggType = cpmminst.getAggType(); + // estimate for in1.join(in2) + long joinedSize = input1.rddStats.distributedSize + input2.rddStats.distributedSize; + RDDStats joinedRDD = new RDDStats(joinedSize, -1); + dataTransmissionTime = getSparkShuffleTime(joinedRDD, executorMetrics, true); + if (aggType == AggBinaryOp.SparkAggType.SINGLE_BLOCK) { + dataTransmissionTime += getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics); + output.rddStats.isCollected = true; + } else { + dataTransmissionTime += getSparkShuffleTime(output.rddStats, executorMetrics, true); + output.rddStats.hashPartitioned = true; + } + numPartitionsForMapping = joinedRDD.numPartitions; + } else if (inst instanceof RmmSPInstruction) { + // estimate for in1.join(in2) + long joinedSize = input1.rddStats.distributedSize + input2.rddStats.distributedSize; + RDDStats joinedRDD = new RDDStats(joinedSize, -1); + dataTransmissionTime = getSparkShuffleTime(joinedRDD, executorMetrics, true); + // estimate for out.combineByKey() per partition + dataTransmissionTime += getSparkShuffleTime(output.rddStats, executorMetrics, false); + output.rddStats.hashPartitioned = true; + numPartitionsForMapping = joinedRDD.numPartitions; + } else if (inst instanceof MapmmSPInstruction) { + dataTransmissionTime = 
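+ // Overview of the matrix-multiply branches in this method (descriptive
+ // only): cpmm and rmm model a join() of both inputs, sized as the sum of
+ // their distributed sizes; mapmm and pmm broadcast the second input
+ // instead; zipmm assumes a zip-like shuffle without re-distribution.
+ // Single-block aggregation results are charged a collect at the driver,
+ // multi-block results a shuffle.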
+	public static double getMatMulChainInstTime(MapmmChainSPInstruction inst, VarStats input1, VarStats input2, VarStats input3, VarStats output,
+			IOMetrics driverMetrics, IOMetrics executorMetrics) {
+		double dataTransmissionTime = 0;
+		if (input3 != null) {
+			dataTransmissionTime += getSparkBroadcastTime(input3, driverMetrics, executorMetrics);
+		}
+		dataTransmissionTime += getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
+		dataTransmissionTime += getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
+		output.rddStats.isCollected = true;
+
+		long nflop = getInstNFLOP(SPType.MAPMMCHAIN, inst.getOpcode(), output, input1, input2);
+		double mapTime = getCPUTime(nflop, input1.rddStats.numPartitions, executorMetrics, output.rddStats, input1.rddStats);
+		return dataTransmissionTime + mapTime;
+	}
+
+	public static double getCtableInstTime(CtableSPInstruction tableInst, VarStats input1, VarStats input2, VarStats input3, VarStats output, IOMetrics executorMetrics) {
+		String opcode = tableInst.getOpcode();
+		double shuffleTime;
+		if (opcode.equals("ctableexpand") || !input2.isScalar() && input3.isScalar()) { // CTABLE_EXPAND_SCALAR_WEIGHT/CTABLE_TRANSFORM_SCALAR_WEIGHT
+			// in1.join(in2)
+			shuffleTime = getSparkShuffleTime(input2.rddStats, executorMetrics, true);
+		} else if (input2.isScalar() && input3.isScalar()) { // CTABLE_TRANSFORM_HISTOGRAM
+			// no joins
+			shuffleTime = 0;
+		} else if (input2.isScalar() && !input3.isScalar()) { // CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM
+			// in1.join(in3)
+			shuffleTime = getSparkShuffleTime(input3.rddStats, executorMetrics, true);
+		} else { // CTABLE_TRANSFORM
+			// in1.join(in2).join(in3)
+			shuffleTime = getSparkShuffleTime(input2.rddStats, executorMetrics, true);
+			shuffleTime += getSparkShuffleTime(input3.rddStats, executorMetrics, true);
+		}
+		// combineByKey()
+		shuffleTime += getSparkShuffleTime(output.rddStats, executorMetrics, true);
+		output.rddStats.hashPartitioned = true;
+
+		long nflop = getInstNFLOP(SPType.Ctable, opcode, output, input1, input2, input3);
+		double mapTime = getCPUTime(nflop, output.rddStats.numPartitions, executorMetrics,
+				output.rddStats, input1.rddStats, input2.rddStats, input3.rddStats);
+
+		return shuffleTime + mapTime;
+	}
+
+	public static double getParameterizedBuiltinInstTime(ParameterizedBuiltinSPInstruction paramInst, VarStats input1, VarStats input2, VarStats output, IOMetrics driverMetrics, IOMetrics executorMetrics) {
+		String opcode = paramInst.getOpcode();
+		double dataTransmissionTime;
+		switch (opcode) {
+			case "rmempty":
+				if (input2.rddStats == null) // broadcast
+					dataTransmissionTime = getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
+				else // join
+					dataTransmissionTime = getSparkShuffleTime(input1.rddStats, executorMetrics, true);
+				dataTransmissionTime += getSparkShuffleTime(output.rddStats, executorMetrics, true);
+				break;
+			case "contains":
+				if (input2.isScalar()) {
+					dataTransmissionTime = 0;
+				} else {
+					dataTransmissionTime = getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
+					// ignore reduceByKey() cost
+				}
+				output.rddStats.isCollected = true;
+				break;
+			case "replace":
+			case "lowertri":
+			case "uppertri":
+				dataTransmissionTime = 0;
+				break;
+			default:
+				throw new RuntimeException("Spark operation ParameterizedBuiltin with opcode " + opcode + " is not supported yet");
+		}
+
+		long nflop = getInstNFLOP(paramInst.getSPInstructionType(), opcode, output, input1);
+		double mapTime = getCPUTime(nflop, input1.rddStats.numPartitions, executorMetrics, output.rddStats, input1.rddStats);
+
+		return dataTransmissionTime + mapTime;
+	}
+
+	/**
+	 * Computes an estimate for the time needed by the CPU to execute (including memory access)
+	 * an instruction, given the number of floating point operations.
+	 *
+	 * @param nflop number of FLOP to execute the target CPU operation
+	 * @param numPartitions number of partitions used to execute the target operation;
+	 *        not bound to any of the input/output statistics objects to allow more
+	 *        flexibility depending on the corresponding instruction
+	 * @param executorMetrics metrics for the executors utilized by the Spark cluster
+	 * @param output statistics for the output variable
+	 * @param inputs array of statistics for the input variables
+	 * @return time estimate
+	 */
+	public static double getCPUTime(long nflop, int numPartitions, IOMetrics executorMetrics, RDDStats output, RDDStats...inputs) {
+		double memScanTime = 0;
+		for (RDDStats input: inputs) {
+			if (input == null) continue;
+			// compensates for spill-overs to account for non-compute bound operations
+			memScanTime += getMemReadTime(input, executorMetrics);
+		}
+		double numWaves = Math.ceil((double) numPartitions / SparkExecutionContext.getDefaultParallelism(false));
+		double scaledNFLOP = (numWaves * nflop) / numPartitions;
+		double cpuComputationTime = scaledNFLOP / executorMetrics.cpuFLOPS;
+		double memWriteTime = output != null? getMemWriteTime(output, executorMetrics) : 0;
+		return Math.max(memScanTime, cpuComputationTime) + memWriteTime;
+	}
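getCPUTime() scales compute by scheduling waves: with p partitions and c total executor cores (SparkExecutionContext.getDefaultParallelism), tasks run in ceil(p/c) waves, so each core's FLOP share is nflop * waves / p. Input scanning overlaps with compute (hence the max), and writing the output is additive. A worked example with made-up metrics, placeholders rather than measured SystemDS values:

	// Illustrative only: the wave-based scaling used by getCPUTime().
	public class CpuTimeSketch {
		public static void main(String[] args) {
			long nflop = 8_000_000_000L; // hypothetical total FLOP of the operator
			int numPartitions = 64;      // partitions of the scanned input
			int defaultParallelism = 16; // total executor cores
			double cpuFLOPS = 2e9;       // assumed per-core throughput
			double memScanTime = 1.2;    // assumed input scan time (s)
			double memWriteTime = 0.4;   // assumed output write time (s)
			double numWaves = Math.ceil((double) numPartitions / defaultParallelism);
			double computeTime = (numWaves * nflop) / numPartitions / cpuFLOPS;
			double total = Math.max(memScanTime, computeTime) + memWriteTime;
			System.out.printf("waves=%.0f compute=%.2fs total=%.2fs%n", numWaves, computeTime, total);
		}
	}

Here 64 partitions on 16 cores give 4 waves and 0.25s of compute, so the 1.2s memory scan dominates and the estimate is 1.6s, an example of the non-compute-bound case the memScanTime term exists for.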
+	public static void assignOutputRDDStats(SPInstruction inst, VarStats output, VarStats...inputs) {
+		if (!output.isScalar()) {
+			SPType instType = inst.getSPInstructionType();
+			String opcode = inst.getOpcode();
+			if (output.getCells() < 0) {
+				inferStats(instType, opcode, output, inputs);
+			}
+		}
+		output.rddStats = new RDDStats(output);
+	}
+
+	private static void inferStats(SPType instType, String opcode, VarStats output, VarStats...inputs) {
+		switch (instType) {
+			case Unary:
+			case Builtin:
+				CPCostUtils.inferStats(CPType.Unary, opcode, output, inputs);
+				break;
+			case AggregateUnary:
+			case AggregateUnarySketch:
+				CPCostUtils.inferStats(CPType.AggregateUnary, opcode, output, inputs);
+				break;
+			case MatrixIndexing:
+				CPCostUtils.inferStats(CPType.MatrixIndexing, opcode, output, inputs);
+				break;
+			case Reorg:
+				CPCostUtils.inferStats(CPType.Reorg, opcode, output, inputs);
+				break;
+			case Binary:
+				CPCostUtils.inferStats(CPType.Binary, opcode, output, inputs);
+				break;
+			case CPMM:
+			case RMM:
+			case MAPMM:
+			case PMM:
+			case ZIPMM:
+				CPCostUtils.inferStats(CPType.AggregateBinary, opcode, output, inputs);
+				break;
+			case ParameterizedBuiltin:
+				CPCostUtils.inferStats(CPType.ParameterizedBuiltin, opcode, output, inputs);
+				break;
+			case Rand:
+				CPCostUtils.inferStats(CPType.Rand, opcode, output, inputs);
+				break;
+			case Ctable:
+				CPCostUtils.inferStats(CPType.Ctable, opcode, output, inputs);
+				break;
+			default:
+				throw new RuntimeException("Operation of type "+instType+" with opcode '"+opcode+"' has no formula for inferring dimensions");
+		}
+		if (output.getCells() < 0) {
+			throw new RuntimeException("Operation of type "+instType+" with opcode '"+opcode+"' has incomplete formula for inferring dimensions");
+		}
+	}
+
+	private static long getInstNFLOP(
+			SPType instructionType,
+			String opcode,
+			VarStats output,
+			VarStats...inputs
+	) {
+		opcode = opcode.toLowerCase();
+		double costs;
+		switch (instructionType) {
+			case Reblock:
+				if (opcode.startsWith("libsvm")) {
+					return output.getCellsWithSparsity();
+				} else { // starts with "rblk" or "csvrblk"
+					return output.getCells();
+				}
+			case Unary:
+			case Builtin:
+				return CPCostUtils.getInstNFLOP(CPType.Unary, opcode, output, inputs);
+			case AggregateUnary:
+			case AggregateUnarySketch:
+				switch (opcode) {
+					case "uacdr":
+					case "uacdc":
+						throw new DMLRuntimeException(opcode + " opcode is not implemented by SystemDS");
+					default:
+						return CPCostUtils.getInstNFLOP(CPType.AggregateUnary, opcode, output, inputs);
+				}
+			case CumsumAggregate:
+				switch (opcode) {
+					case "ucumack+":
+					case "ucumac*":
+					case "ucumacmin":
+					case "ucumacmax":
+						costs = 1; break;
+					case "ucumac+*":
+						costs = 2; break;
+					default:
+						throw new DMLRuntimeException(opcode + " opcode is not implemented by SystemDS");
+				}
+				return (long) (costs * inputs[0].getCells() + costs * output.getN());
+			case TSMM:
+			case TSMM2:
+				return CPCostUtils.getInstNFLOP(CPType.MMTSJ, opcode, output, inputs);
+			case Reorg:
+			case MatrixReshape:
+				return CPCostUtils.getInstNFLOP(CPType.Reorg, opcode, output, inputs);
+			case MatrixIndexing:
+				// the actual opcode value is not used at the moment
+				return CPCostUtils.getInstNFLOP(CPType.MatrixIndexing, opcode, output, inputs);
+			case Cast:
+				return output.getCellsWithSparsity();
+			case QSort:
+				return CPCostUtils.getInstNFLOP(CPType.QSort, opcode, output, inputs);
+			case CentralMoment:
+				return CPCostUtils.getInstNFLOP(CPType.CentralMoment, opcode, output, inputs);
+			case UaggOuterChain:
+			case Dnn:
+				throw new RuntimeException("Spark operation type '" + instructionType + "' is not supported yet");
+			// types corresponding to BinaryCPInstruction
+			case Binary:
+				switch (opcode) {
+					case "+*":
+					case "-*":
+						// original "map+*" and "map-*"
+						// "+*" and "-*" defined as ternary
+						throw new RuntimeException("Spark operation with opcode '" + opcode + "' is not supported yet");
+					default:
+						return CPCostUtils.getInstNFLOP(CPType.Binary, opcode, output, inputs);
+				}
+			case CPMM:
+			case RMM:
+			case MAPMM:
+			case PMM:
+			case ZIPMM:
+			case PMAPMM:
+				// do not reduce by factor of 2: not explicit matrix multiplication
+				return 2 * CPCostUtils.getInstNFLOP(CPType.AggregateBinary, opcode, output, inputs);
+			case MAPMMCHAIN:
+				return 2 * inputs[0].getCells() * inputs[0].getN() // ba(+*)
+						+ 2 * inputs[0].getM() * inputs[1].getN() // cellwise b(*) + r(t)
+						+ 2 * inputs[0].getCellsWithSparsity() * inputs[1].getN() // ba(+*)
+						+ inputs[1].getM() * output.getM(); //r(t)
+			case BinUaggChain:
+				break;
+			case MAppend:
+			case RAppend:
+			case GAppend:
+			case GAlignedAppend:
+				// the actual opcode value is not used at the moment
+				return CPCostUtils.getInstNFLOP(CPType.Append, opcode, output, inputs);
+			case BuiltinNary:
+				return CPCostUtils.getInstNFLOP(CPType.BuiltinNary, opcode, output, inputs);
+			case Ctable:
+				return CPCostUtils.getInstNFLOP(CPType.Ctable, opcode, output, inputs);
+			case ParameterizedBuiltin:
+				return CPCostUtils.getInstNFLOP(CPType.ParameterizedBuiltin, opcode, output, inputs);
+			default:
+				// all existing cases should have been handled above
+				throw new DMLRuntimeException("Spark operation type '" + instructionType + "' is not supported by SystemDS");
+		}
+		throw new RuntimeException();
+	}
+
+
+//			//ternary aggregate operators
+//			case "tak+*":
+//				break;
+//			case "tack+*":
+//				break;
+//			// Neural network operators
+//			case "conv2d":
+//			case "conv2d_bias_add":
+//			case "maxpooling":
+//			case "relu_maxpooling":
+//			case RightIndex.OPCODE:
+//			case LeftIndex.OPCODE:
+//			case "mapLeftIndex":
+//			case "_map",:
+//				break;
+//			// Spark-specific instructions
+//			case Checkpoint.DEFAULT_CP_OPCODE,:
+//				break;
+//			case Checkpoint.ASYNC_CP_OPCODE,:
+//				break;
+//			case Compression.OPCODE,:
+//				break;
+//			case DeCompression.OPCODE,:
+//				break;
+//			// Parameterized Builtin Functions
+//			case "autoDiff",:
+//				break;
+//			case "contains",:
+//				break;
+//			case "groupedagg",:
+//				break;
+//			case "mapgroupedagg",:
+//				break;
+//			case "rmempty",:
+//				break;
+//			case "replace",:
+//				break;
+//			case "rexpand",:
+//				break;
+//			case "lowertri",:
+//				break;
+//			case "uppertri",:
+//				break;
+//			case "tokenize",:
+//				break;
+//			case "transformapply",:
+//				break;
+//			case "transformdecode",:
+//				break;
+//			case "transformencode",:
+//				break;
+//			case "mappend",:
+//				break;
+//			case "rappend",:
+//				break;
+//			case "gappend",:
+//				break;
+//			case "galignedappend",:
+//				break;
+//			//ternary instruction opcodes
+//			case "ctable",:
+//				break;
+//			case "ctableexpand",:
+//				break;
 //
-//			//ternary instruction opcodes
-//			case "+*",:
-//				break;
-//			case "-*",:
-//				break;
-//			case "ifelse",:
-//				break;
 //
-//			//quaternary instruction opcodes
-//			case WeightedSquaredLoss.OPCODE,:
-//				break;
-//			case WeightedSquaredLossR.OPCODE,:
-//				break;
-//			case WeightedSigmoid.OPCODE,:
-//				break;
-//			case
WeightedSigmoidR.OPCODE,: -// break; -// case WeightedDivMM.OPCODE,: -// break; -// case WeightedDivMMR.OPCODE,: -// break; -// case WeightedCrossEntropy.OPCODE,: -// break; -// case WeightedCrossEntropyR.OPCODE,: -// break; -// case WeightedUnaryMM.OPCODE,: -// break; -// case WeightedUnaryMMR.OPCODE,: -// break; -// case "bcumoffk+": -// break; -// case "bcumoff*": -// break; -// case "bcumoff+*": -// break; -// case "bcumoffmin",: -// break; -// case "bcumoffmax",: -// break; +// //quaternary instruction opcodes +// case WeightedSquaredLoss.OPCODE,: +// break; +// case WeightedSquaredLossR.OPCODE,: +// break; +// case WeightedSigmoid.OPCODE,: +// break; +// case WeightedSigmoidR.OPCODE,: +// break; +// case WeightedDivMM.OPCODE,: +// break; +// case WeightedDivMMR.OPCODE,: +// break; +// case WeightedCrossEntropy.OPCODE,: +// break; +// case WeightedCrossEntropyR.OPCODE,: +// break; +// case WeightedUnaryMM.OPCODE,: +// break; +// case WeightedUnaryMMR.OPCODE,: +// break; +// case "bcumoffk+": +// break; +// case "bcumoff*": +// break; +// case "bcumoff+*": +// break; +// case "bcumoffmin",: +// break; +// case "bcumoffmax",: +// break; // -// //central moment, covariance, quantiles (sort/pick) -// case "cm" ,: -// break; -// case "cov" ,: -// break; -// case "qsort" ,: -// break; -// case "qpick" ,: -// break; +// //central moment, covariance, quantiles (sort/pick) +// case "cm" ,: +// break; +// case "cov" ,: +// break; +// case "qsort" ,: +// break; +// case "qpick" ,: +// break; // -// case "binuaggchain",: -// break; +// case "binuaggchain",: +// break; // -// case "write" ,: -// break; +// case "write" ,: +// break; // // -// case "spoof": -// break; -// default: -// throw RuntimeException("No complexity factor for op. code: " + opcode); -// } +// case "spoof": +// break; +// default: +// throw RuntimeException("No complexity factor for op. code: " + opcode); +// } } diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/RollIndex.java b/src/main/java/org/apache/sysds/runtime/functionobjects/RollIndex.java index 5bd78bb703c..976a172e8b8 100644 --- a/src/main/java/org/apache/sysds/runtime/functionobjects/RollIndex.java +++ b/src/main/java/org/apache/sysds/runtime/functionobjects/RollIndex.java @@ -29,37 +29,37 @@ * in ReorgOperator in order to identify sort operations. 
*/ public class RollIndex extends IndexFunction { - private static final long serialVersionUID = -8446389232078905200L; + private static final long serialVersionUID = -8446389232078905200L; - private final int _shift; + private final int _shift; - public RollIndex(int shift) { - _shift = shift; - } + public RollIndex(int shift) { + _shift = shift; + } - public int getShift() { - return _shift; - } + public int getShift() { + return _shift; + } - @Override - public boolean computeDimension(int row, int col, CellIndex retDim) { - retDim.set(row, col); - return false; - } + @Override + public boolean computeDimension(int row, int col, CellIndex retDim) { + retDim.set(row, col); + return false; + } - @Override - public boolean computeDimension(DataCharacteristics in, DataCharacteristics out) { - out.set(in.getRows(), in.getCols(), in.getBlocksize(), in.getNonZeros()); - return false; - } + @Override + public boolean computeDimension(DataCharacteristics in, DataCharacteristics out) { + out.set(in.getRows(), in.getCols(), in.getBlocksize(), in.getNonZeros()); + return false; + } - @Override - public void execute(MatrixIndexes in, MatrixIndexes out) { - throw new NotImplementedException(); - } + @Override + public void execute(MatrixIndexes in, MatrixIndexes out) { + throw new NotImplementedException(); + } - @Override - public void execute(CellIndex in, CellIndex out) { - throw new NotImplementedException(); - } + @Override + public void execute(CellIndex in, CellIndex out) { + throw new NotImplementedException(); + } } diff --git a/src/test/java/org/apache/sysds/performance/compression/TransformPerf.java b/src/test/java/org/apache/sysds/performance/compression/TransformPerf.java index 7d85a30b794..7c8243125c1 100644 --- a/src/test/java/org/apache/sysds/performance/compression/TransformPerf.java +++ b/src/test/java/org/apache/sysds/performance/compression/TransformPerf.java @@ -75,7 +75,6 @@ private void updateGen() { } } - @SuppressWarnings("unused") private void detectSchema(int k) { FrameBlock fb = gen.take(); long in = fb.getInMemorySize(); @@ -84,7 +83,6 @@ private void detectSchema(int k) { ret.add(new InOut(in, out)); } - @SuppressWarnings("unused") private void detectAndApply(int k) { FrameBlock fb = gen.take(); long in = fb.getInMemorySize(); @@ -94,7 +92,6 @@ private void detectAndApply(int k) { ret.add(new InOut(in, outS)); } - @SuppressWarnings("unused") private void transformEncode(int k) { FrameBlock fb = gen.take(); long in = fb.getInMemorySize(); diff --git a/src/test/java/org/apache/sysds/test/component/resource/CPCostUtilsTest.java b/src/test/java/org/apache/sysds/test/component/resource/CPCostUtilsTest.java index 303f2be4ab8..56624a6371e 100644 --- a/src/test/java/org/apache/sysds/test/component/resource/CPCostUtilsTest.java +++ b/src/test/java/org/apache/sysds/test/component/resource/CPCostUtilsTest.java @@ -29,553 +29,553 @@ public class CPCostUtilsTest { - @Test - public void testUnaryNotInstNFLOP() { - long expectedValue = 1000 * 1000; - testUnaryInstNFLOP("!", -1, -1, expectedValue); - } - - @Test - public void testUnaryIsnaInstNFLOP() { - long expectedValue = 1000 * 1000; - testUnaryInstNFLOP("isna", -1, -1, expectedValue); - } - - @Test - public void testUnaryIsnanInstNFLOP() { - long expectedValue = 1000 * 1000; - testUnaryInstNFLOP("isnan", -1, -1, expectedValue); - } - - @Test - public void testUnaryIsinfInstNFLOP() { - long expectedValue = 1000 * 1000; - testUnaryInstNFLOP("isinf", -1, -1, expectedValue); - } - - @Test - public void testUnaryCeilInstNFLOP() { - long 
expectedValue = 1000 * 1000; - testUnaryInstNFLOP("ceil", -1, -1, expectedValue); - } - - @Test - public void testUnaryFloorInstNFLOP() { - long expectedValue = 1000 * 1000; - testUnaryInstNFLOP("floor", -1, -1, expectedValue); - } - - @Test - public void testAbsInstNFLOPDefaultSparsity() { - long expectedValue = 1000 * 1000; - testUnaryInstNFLOP("abs", -1, -1, expectedValue); - } - - @Test - public void testAbsInstNFLOPSparse() { - long expectedValue = (long) (0.5 * 1000 * 1000); - testUnaryInstNFLOP("abs", 0.5, 0.5, expectedValue); - } - - @Test - public void testRoundInstNFLOPDefaultSparsity() { - long expectedValue = 1000 * 1000; - testUnaryInstNFLOP("round", -1, -1, expectedValue); - } - - @Test - public void testRoundInstNFLOPSparse() { - long expectedValue = (long) (0.5 * 1000 * 1000); - testUnaryInstNFLOP("round", 0.5, 0.5, expectedValue); - } - - @Test - public void testSignInstNFLOPDefaultSparsity() { - long expectedValue = 1000 * 1000; - testUnaryInstNFLOP("sign", -1, -1, expectedValue); - } - - @Test - public void testSignInstNFLOPSparse() { - long expectedValue = (long) (0.5 * 1000 * 1000); - testUnaryInstNFLOP("sign", 0.5, 0.5, expectedValue); - } - - @Test - public void testSpropInstNFLOPDefaultSparsity() { - long expectedValue = 2 * 1000 * 1000; - testUnaryInstNFLOP("sprop", -1, -1, expectedValue); - } - - @Test - public void testSpropInstNFLOPSparse() { - long expectedValue = (long) (2 * 0.5 * 1000 * 1000); - testUnaryInstNFLOP("sprop", 0.5, 0.5, expectedValue); - } - - @Test - public void testSqrtInstNFLOPDefaultSparsity() { - long expectedValue = 2 * 1000 * 1000; - testUnaryInstNFLOP("sqrt", -1, -1, expectedValue); - } - - @Test - public void testSqrtInstNFLOPSparse() { - long expectedValue = (long) (2 * 0.5 * 1000 * 1000); - testUnaryInstNFLOP("sqrt", 0.5, 0.5, expectedValue); - } - - @Test - public void testExpInstNFLOPDefaultSparsity() { - long expectedValue = 18 * 1000 * 1000; - testUnaryInstNFLOP("exp", -1, -1, expectedValue); - } - - @Test - public void testExpInstNFLOPSparse() { - long expectedValue = (long) (18 * 0.5 * 1000 * 1000); - testUnaryInstNFLOP("exp", 0.5, 0.5, expectedValue); - } - - @Test - public void testSigmoidInstNFLOPDefaultSparsity() { - long expectedValue = 21 * 1000 * 1000; - testUnaryInstNFLOP("sigmoid", -1, -1, expectedValue); - } - - @Test - public void testSigmoidInstNFLOPSparse() { - long expectedValue = (long) (21 * 0.5 * 1000 * 1000); - testUnaryInstNFLOP("sigmoid", 0.5, 0.5, expectedValue); - } - - @Test - public void testPlogpInstNFLOPDefaultSparsity() { - long expectedValue = 32 * 1000 * 1000; - testUnaryInstNFLOP("plogp", -1, -1, expectedValue); - } - - @Test - public void testPlogpInstNFLOPSparse() { - long expectedValue = (long) (32 * 0.5 * 1000 * 1000); - testUnaryInstNFLOP("plogp", 0.5, 0.5, expectedValue); - } - - @Test - public void testPrintInstNFLOP() { - long expectedValue = 1000 * 1000; - testUnaryInstNFLOP("print", -1, -1, expectedValue); - } - - @Test - public void testAssertInstNFLOP() { - long expectedValue = 1000 * 1000; - testUnaryInstNFLOP("assert", -1, -1, expectedValue); - } - - @Test - public void testSinInstNFLOPDefaultSparsity() { - long expectedValue = 18 * 1000 * 1000; - testUnaryInstNFLOP("sin", -1, -1, expectedValue); - } - - @Test - public void testSinInstNFLOPSparse() { - long expectedValue = (long) (18 * 0.5 * 1000 * 1000); - testUnaryInstNFLOP("sin", 0.5, 0.5, expectedValue); - } - - @Test - public void testCosInstNFLOPDefaultSparsity() { - long expectedValue = 22 * 1000 * 1000; - 
testUnaryInstNFLOP("cos", -1, -1, expectedValue); - } - - @Test - public void testCosInstNFLOPSparse() { - long expectedValue = (long) (22 * 0.5 * 1000 * 1000); - testUnaryInstNFLOP("cos", 0.5, 0.5, expectedValue); - } - - @Test - public void testTanInstNFLOPDefaultSparsity() { - long expectedValue = 42 * 1000 * 1000; - testUnaryInstNFLOP("tan", -1, -1, expectedValue); - } - - @Test - public void testTanInstNFLOPSparse() { - long expectedValue = (long) (42 * 0.5 * 1000 * 1000); - testUnaryInstNFLOP("tan", 0.5, 0.5, expectedValue); - } - - @Test - public void testAsinInstNFLOP() { - long expectedValue = 93 * 1000 * 1000; - testUnaryInstNFLOP("asin", -1, -1, expectedValue); - } - - @Test - public void testSinhInstNFLOP() { - long expectedValue = 93 * 1000 * 1000; - testUnaryInstNFLOP("sinh", -1, -1, expectedValue); - } - - @Test - public void testAcosInstNFLOP() { - long expectedValue = 103 * 1000 * 1000; - testUnaryInstNFLOP("acos", -1, -1, expectedValue); - } - - @Test - public void testCoshInstNFLOP() { - long expectedValue = 103 * 1000 * 1000; - testUnaryInstNFLOP("cosh", -1, -1, expectedValue); - } - - @Test - public void testAtanInstNFLOP() { - long expectedValue = 40 * 1000 * 1000; - testUnaryInstNFLOP("atan", -1, -1, expectedValue); - } - - @Test - public void testTanhInstNFLOP() { - long expectedValue = 40 * 1000 * 1000; - testUnaryInstNFLOP("tanh", -1, -1, expectedValue); - } - - @Test - public void testUcumkPlusInstNFLOPDefaultSparsity() { - long expectedValue = 1000 * 1000; - testUnaryInstNFLOP("ucumk+", -1, -1, expectedValue); - } - - @Test - public void testUcumkPlusInstNFLOPSparse() { - long expectedValue = (long) (0.5 * 1000 * 1000); - testUnaryInstNFLOP("ucumk+", 0.5, 0.5, expectedValue); - } - - @Test - public void testUcumMinInstNFLOPDefaultSparsity() { - long expectedValue = 1000 * 1000; - testUnaryInstNFLOP("ucummin", -1, -1, expectedValue); - } - - @Test - public void testUcumMinInstNFLOPSparse() { - long expectedValue = (long) (0.5 * 1000 * 1000); - testUnaryInstNFLOP("ucummin", 0.5, 0.5, expectedValue); - } - - @Test - public void testUcumMaxInstNFLOPDefaultSparsity() { - long expectedValue = 1000 * 1000; - testUnaryInstNFLOP("ucummax", -1, -1, expectedValue); - } - - @Test - public void testUcumMaxInstNFLOPSparse() { - long expectedValue = (long) (0.5 * 1000 * 1000); - testUnaryInstNFLOP("ucummax", 0.5, 0.5, expectedValue); - } - - @Test - public void testUcumMultInstNFLOPDefaultSparsity() { - long expectedValue = 1000 * 1000; - testUnaryInstNFLOP("ucum*", -1, -1, expectedValue); - } - - @Test - public void testUcumMultInstNFLOPSparse() { - long expectedValue = (long) (0.5 * 1000 * 1000); - testUnaryInstNFLOP("ucum*", 0.5, 0.5, expectedValue); - } - - @Test - public void testUcumkPlusMultInstNFLOPDefaultSparsity() { - long expectedValue = 2 * 1000 * 1000; - testUnaryInstNFLOP("ucumk+*", -1, -1, expectedValue); - } - - @Test - public void testUcumkPlusMultInstNFLOPSparse() { - long expectedValue = (long) (2 * 0.5 * 1000 * 1000); - testUnaryInstNFLOP("ucumk+*", 0.5, 0.5, expectedValue); - } - - @Test - public void testStopInstNFLOP() { - long expectedValue = 0; - testUnaryInstNFLOP("stop", -1, -1, expectedValue); - } - - @Test - public void testTypeofInstNFLOP() { - long expectedValue = 1000 * 1000; - testUnaryInstNFLOP("typeof", -1, -1, expectedValue); - } - - @Test - public void testInverseInstNFLOPDefaultSparsity() { - long expectedValue = (long) ((4.0 / 3.0) * (1000 * 1000) * (1000 * 1000) * (1000 * 1000)); - testUnaryInstNFLOP("inverse", -1, -1, expectedValue); - 
} - - @Test - public void testInverseInstNFLOPSparse() { - long expectedValue = (long) ((4.0 / 3.0) * (1000 * 1000) * (0.5 * 1000 * 1000) * (0.5 *1000 * 1000)); - testUnaryInstNFLOP("inverse", 0.5, 0.5, expectedValue); - } - - @Test - public void testCholeskyInstNFLOPDefaultSparsity() { - long expectedValue = (long) ((1.0 / 3.0) * (1000 * 1000) * (1000 * 1000) * (1000 * 1000)); - testUnaryInstNFLOP("cholesky", -1, -1, expectedValue); - } - - @Test - public void testCholeskyInstNFLOPSparse() { - long expectedValue = (long) ((1.0 / 3.0) * (1000 * 1000) * (0.5 * 1000 * 1000) * (0.5 *1000 * 1000)); - testUnaryInstNFLOP("cholesky", 0.5, 0.5, expectedValue); - } - - @Test - public void testLogInstNFLOP() { - long expectedValue = 32 * 1000 * 1000; - testBuiltinInstNFLOP("log", -1, expectedValue); - } - - @Test - public void testLogNzInstNFLOPDefaultSparsity() { - long expectedValue = 32 * 1000 * 1000; - testBuiltinInstNFLOP("log_nz", -1, expectedValue); - } - - @Test - public void testLogNzInstNFLOPSparse() { - long expectedValue = (long) (32 * 0.5 * 1000 * 1000); - testBuiltinInstNFLOP("log_nz", 0.5, expectedValue); - } - - @Test - public void testNrowInstNFLOP() { - long expectedValue = 10L; - testAggregateUnaryInstNFLOP("nrow", expectedValue); - } - - @Test - public void testNcolInstNFLOP() { - long expectedValue = 10L; - testAggregateUnaryInstNFLOP("ncol", expectedValue); - } - - @Test - public void testLengthInstNFLOP() { - long expectedValue = 10L; - testAggregateUnaryInstNFLOP("length", expectedValue); - } - - @Test - public void testExistsInstNFLOP() { - long expectedValue = 10L; - testAggregateUnaryInstNFLOP("exists", expectedValue); - } - - @Test - public void testLineageInstNFLOP() { - long expectedValue = 10L; - testAggregateUnaryInstNFLOP("lineage", expectedValue); - } - - @Test - public void testUakInstNFLOP() { - long expectedValue = 4 * 1000 * 1000; - testAggregateUnaryInstNFLOP("uak+", expectedValue); - } - - @Test - public void testUarkInstNFLOP() { - long expectedValue = 4L * 2000 * 2000; - testAggregateUnaryRowInstNFLOP("uark+", -1, expectedValue); - testAggregateUnaryRowInstNFLOP("uark+", 0.5, expectedValue); - } - - @Test - public void testUackInstNFLOP() { - long expectedValue = 4L * 3000 * 3000; - testAggregateUnaryColInstNFLOP("uack+", -1, expectedValue); - testAggregateUnaryColInstNFLOP("uack+", 0.5, expectedValue); - } - - @Test - public void testUasqkInstNFLOP() { - long expectedValue = 5L * 1000 * 1000; - testAggregateUnaryInstNFLOP("uasqk+", expectedValue); - } - - @Test - public void testUarsqkInstNFLOP() { - long expectedValue = 5L * 2000 * 2000; - testAggregateUnaryRowInstNFLOP("uarsqk+", -1, expectedValue); - testAggregateUnaryRowInstNFLOP("uarsqk+", 0.5, expectedValue); - } - - @Test - public void testUacsqkInstNFLOP() { - long expectedValue = 5L * 3000 * 3000; - testAggregateUnaryColInstNFLOP("uacsqk+", -1, expectedValue); - testAggregateUnaryColInstNFLOP("uacsqk+", 0.5, expectedValue); - } - - @Test - public void testUameanInstNFLOP() { - long expectedValue = 7L * 1000 * 1000; - testAggregateUnaryInstNFLOP("uamean", expectedValue); - } - - @Test - public void testUarmeanInstNFLOP() { - long expectedValue = 7L * 2000 * 2000; - testAggregateUnaryRowInstNFLOP("uarmean", -1, expectedValue); - testAggregateUnaryRowInstNFLOP("uarmean", 0.5, expectedValue); - } - - @Test - public void testUacmeanInstNFLOP() { - long expectedValue = 7L * 3000 * 3000; - testAggregateUnaryColInstNFLOP("uacmean", -1, expectedValue); - testAggregateUnaryColInstNFLOP("uacmean", 0.5, 
expectedValue); - } - - @Test - public void testUavarInstNFLOP() { - long expectedValue = 14L * 1000 * 1000; - testAggregateUnaryInstNFLOP("uavar", expectedValue); - } - - @Test - public void testUarvarInstNFLOP() { - long expectedValue = 14L * 2000 * 2000; - testAggregateUnaryRowInstNFLOP("uarvar", -1, expectedValue); - testAggregateUnaryRowInstNFLOP("uarvar", 0.5, expectedValue); - } - - @Test - public void testUacvarInstNFLOP() { - long expectedValue = 14L * 3000 * 3000; - testAggregateUnaryColInstNFLOP("uacvar", -1, expectedValue); - testAggregateUnaryColInstNFLOP("uacvar", 0.5, expectedValue); - } - - @Test - public void testUamaxInstNFLOP() { - long expectedValue = 1000 * 1000; - testAggregateUnaryInstNFLOP("uamax", expectedValue); - } - - @Test - public void testUarmaxInstNFLOP() { - long expectedValue = 2000 * 2000; - testAggregateUnaryRowInstNFLOP("uarmax", -1, expectedValue); - testAggregateUnaryRowInstNFLOP("uarmax", 0.5, expectedValue); - } - - @Test - public void testUarimaxInstNFLOP() { - long expectedValue = 2000 * 2000; - testAggregateUnaryRowInstNFLOP("uarimax", -1, expectedValue); - testAggregateUnaryRowInstNFLOP("uarimax", 0.5, expectedValue); - } - - @Test - public void testUacmaxInstNFLOP() { - long expectedValue = 3000 * 3000; - testAggregateUnaryColInstNFLOP("uacmax", -1, expectedValue); - testAggregateUnaryColInstNFLOP("uacmax", 0.5, expectedValue); - } - - @Test - public void testUaminInstNFLOP() { - long expectedValue = 1000 * 1000; - testAggregateUnaryInstNFLOP("uamin", expectedValue); - } - - @Test - public void testUarminInstNFLOP() { - long expectedValue = 2000 * 2000; - testAggregateUnaryRowInstNFLOP("uarmin", -1, expectedValue); - testAggregateUnaryRowInstNFLOP("uarmin", 0.5, expectedValue); - } - - @Test - public void testUariminInstNFLOP() { - long expectedValue = 2000 * 2000; - testAggregateUnaryRowInstNFLOP("uarimin", -1, expectedValue); - testAggregateUnaryRowInstNFLOP("uarimin", 0.5, expectedValue); - } - - @Test - public void testUacminInstNFLOP() { - long expectedValue = 3000 * 3000; - testAggregateUnaryColInstNFLOP("uacmin", -1, expectedValue); - testAggregateUnaryColInstNFLOP("uacmin", 0.5, expectedValue); - } - - // HELPERS - - private void testUnaryInstNFLOP(String opcode, double sparsityIn, double sparsityOut, long expectedNFLOP) { - long nnzIn = sparsityIn < 0? -1 : (long) (sparsityIn * 1000 * 1000); - VarStats input = generateVarStatsMatrix("_mVar1", 1000, 1000, nnzIn); - long nnzOut = sparsityOut < 0? -1 : (long) (sparsityOut * 1000 * 1000); - VarStats output = generateVarStatsMatrix("_mVar2", 1000, 1000, nnzOut); - - long result = CPCostUtils.getInstNFLOP(CPType.Unary, opcode, output, input); - assertEquals(expectedNFLOP, result); - } - - private void testBuiltinInstNFLOP(String opcode, double sparsityIn, long expectedNFLOP) { - long nnz = sparsityIn < 0? 
-1 : (long) (sparsityIn * 1000 * 1000); - VarStats input = generateVarStatsMatrix("_mVar1", 1000, 1000, nnz); - VarStats output = generateVarStatsMatrix("_mVar2", 1000, 1000, -1); - - long result = CPCostUtils.getInstNFLOP(CPType.Unary, opcode, output, input); - assertEquals(expectedNFLOP, result); - } - - private void testAggregateUnaryInstNFLOP(String opcode, long expectedNFLOP) { - VarStats input = generateVarStatsMatrix("_mVar1", 1000, 1000, -1); - VarStats output = generateVarStatsScalarLiteral("_Var2"); - - long result = CPCostUtils.getInstNFLOP(CPType.AggregateUnary, opcode, output, input); - assertEquals(expectedNFLOP, result); - } - - private void testAggregateUnaryRowInstNFLOP(String opcode, double sparsityOut, long expectedNFLOP) { - VarStats input = generateVarStatsMatrix("_mVar1", 2000, 1000, -1); - long nnzOut = sparsityOut < 0? -1 : (long) (sparsityOut * 2000); - VarStats output = generateVarStatsMatrix("_mVar2", 2000, 1, nnzOut); - - long result = CPCostUtils.getInstNFLOP(CPType.AggregateUnary, opcode, output, input); - assertEquals(expectedNFLOP, result); - } - - private void testAggregateUnaryColInstNFLOP(String opcode, double sparsityOut, long expectedNFLOP) { - VarStats input = generateVarStatsMatrix("_mVar1", 1000, 3000, -1); - long nnzOut = sparsityOut < 0? -1 : (long) (sparsityOut * 3000); - VarStats output = generateVarStatsMatrix("_mVar2", 1, 3000, nnzOut); - - long result = CPCostUtils.getInstNFLOP(CPType.AggregateUnary, opcode, output, input); - assertEquals(expectedNFLOP, result); - } - - private VarStats generateVarStatsMatrix(String name, long rows, long cols, long nnz) { - MatrixCharacteristics mc = new MatrixCharacteristics(rows, cols, nnz); - return new VarStats(name, mc); - } - - private VarStats generateVarStatsScalarLiteral(String nameOrValue) { - return new VarStats(nameOrValue, null); - } + @Test + public void testUnaryNotInstNFLOP() { + long expectedValue = 1000 * 1000; + testUnaryInstNFLOP("!", -1, -1, expectedValue); + } + + @Test + public void testUnaryIsnaInstNFLOP() { + long expectedValue = 1000 * 1000; + testUnaryInstNFLOP("isna", -1, -1, expectedValue); + } + + @Test + public void testUnaryIsnanInstNFLOP() { + long expectedValue = 1000 * 1000; + testUnaryInstNFLOP("isnan", -1, -1, expectedValue); + } + + @Test + public void testUnaryIsinfInstNFLOP() { + long expectedValue = 1000 * 1000; + testUnaryInstNFLOP("isinf", -1, -1, expectedValue); + } + + @Test + public void testUnaryCeilInstNFLOP() { + long expectedValue = 1000 * 1000; + testUnaryInstNFLOP("ceil", -1, -1, expectedValue); + } + + @Test + public void testUnaryFloorInstNFLOP() { + long expectedValue = 1000 * 1000; + testUnaryInstNFLOP("floor", -1, -1, expectedValue); + } + + @Test + public void testAbsInstNFLOPDefaultSparsity() { + long expectedValue = 1000 * 1000; + testUnaryInstNFLOP("abs", -1, -1, expectedValue); + } + + @Test + public void testAbsInstNFLOPSparse() { + long expectedValue = (long) (0.5 * 1000 * 1000); + testUnaryInstNFLOP("abs", 0.5, 0.5, expectedValue); + } + + @Test + public void testRoundInstNFLOPDefaultSparsity() { + long expectedValue = 1000 * 1000; + testUnaryInstNFLOP("round", -1, -1, expectedValue); + } + + @Test + public void testRoundInstNFLOPSparse() { + long expectedValue = (long) (0.5 * 1000 * 1000); + testUnaryInstNFLOP("round", 0.5, 0.5, expectedValue); + } + + @Test + public void testSignInstNFLOPDefaultSparsity() { + long expectedValue = 1000 * 1000; + testUnaryInstNFLOP("sign", -1, -1, expectedValue); + } + + @Test + public void 
testSignInstNFLOPSparse() { + long expectedValue = (long) (0.5 * 1000 * 1000); + testUnaryInstNFLOP("sign", 0.5, 0.5, expectedValue); + } + + @Test + public void testSpropInstNFLOPDefaultSparsity() { + long expectedValue = 2 * 1000 * 1000; + testUnaryInstNFLOP("sprop", -1, -1, expectedValue); + } + + @Test + public void testSpropInstNFLOPSparse() { + long expectedValue = (long) (2 * 0.5 * 1000 * 1000); + testUnaryInstNFLOP("sprop", 0.5, 0.5, expectedValue); + } + + @Test + public void testSqrtInstNFLOPDefaultSparsity() { + long expectedValue = 2 * 1000 * 1000; + testUnaryInstNFLOP("sqrt", -1, -1, expectedValue); + } + + @Test + public void testSqrtInstNFLOPSparse() { + long expectedValue = (long) (2 * 0.5 * 1000 * 1000); + testUnaryInstNFLOP("sqrt", 0.5, 0.5, expectedValue); + } + + @Test + public void testExpInstNFLOPDefaultSparsity() { + long expectedValue = 18 * 1000 * 1000; + testUnaryInstNFLOP("exp", -1, -1, expectedValue); + } + + @Test + public void testExpInstNFLOPSparse() { + long expectedValue = (long) (18 * 0.5 * 1000 * 1000); + testUnaryInstNFLOP("exp", 0.5, 0.5, expectedValue); + } + + @Test + public void testSigmoidInstNFLOPDefaultSparsity() { + long expectedValue = 21 * 1000 * 1000; + testUnaryInstNFLOP("sigmoid", -1, -1, expectedValue); + } + + @Test + public void testSigmoidInstNFLOPSparse() { + long expectedValue = (long) (21 * 0.5 * 1000 * 1000); + testUnaryInstNFLOP("sigmoid", 0.5, 0.5, expectedValue); + } + + @Test + public void testPlogpInstNFLOPDefaultSparsity() { + long expectedValue = 32 * 1000 * 1000; + testUnaryInstNFLOP("plogp", -1, -1, expectedValue); + } + + @Test + public void testPlogpInstNFLOPSparse() { + long expectedValue = (long) (32 * 0.5 * 1000 * 1000); + testUnaryInstNFLOP("plogp", 0.5, 0.5, expectedValue); + } + + @Test + public void testPrintInstNFLOP() { + long expectedValue = 1000 * 1000; + testUnaryInstNFLOP("print", -1, -1, expectedValue); + } + + @Test + public void testAssertInstNFLOP() { + long expectedValue = 1000 * 1000; + testUnaryInstNFLOP("assert", -1, -1, expectedValue); + } + + @Test + public void testSinInstNFLOPDefaultSparsity() { + long expectedValue = 18 * 1000 * 1000; + testUnaryInstNFLOP("sin", -1, -1, expectedValue); + } + + @Test + public void testSinInstNFLOPSparse() { + long expectedValue = (long) (18 * 0.5 * 1000 * 1000); + testUnaryInstNFLOP("sin", 0.5, 0.5, expectedValue); + } + + @Test + public void testCosInstNFLOPDefaultSparsity() { + long expectedValue = 22 * 1000 * 1000; + testUnaryInstNFLOP("cos", -1, -1, expectedValue); + } + + @Test + public void testCosInstNFLOPSparse() { + long expectedValue = (long) (22 * 0.5 * 1000 * 1000); + testUnaryInstNFLOP("cos", 0.5, 0.5, expectedValue); + } + + @Test + public void testTanInstNFLOPDefaultSparsity() { + long expectedValue = 42 * 1000 * 1000; + testUnaryInstNFLOP("tan", -1, -1, expectedValue); + } + + @Test + public void testTanInstNFLOPSparse() { + long expectedValue = (long) (42 * 0.5 * 1000 * 1000); + testUnaryInstNFLOP("tan", 0.5, 0.5, expectedValue); + } + + @Test + public void testAsinInstNFLOP() { + long expectedValue = 93 * 1000 * 1000; + testUnaryInstNFLOP("asin", -1, -1, expectedValue); + } + + @Test + public void testSinhInstNFLOP() { + long expectedValue = 93 * 1000 * 1000; + testUnaryInstNFLOP("sinh", -1, -1, expectedValue); + } + + @Test + public void testAcosInstNFLOP() { + long expectedValue = 103 * 1000 * 1000; + testUnaryInstNFLOP("acos", -1, -1, expectedValue); + } + + @Test + public void testCoshInstNFLOP() { + long expectedValue = 103 * 1000 * 1000; + 
testUnaryInstNFLOP("cosh", -1, -1, expectedValue); + } + + @Test + public void testAtanInstNFLOP() { + long expectedValue = 40 * 1000 * 1000; + testUnaryInstNFLOP("atan", -1, -1, expectedValue); + } + + @Test + public void testTanhInstNFLOP() { + long expectedValue = 40 * 1000 * 1000; + testUnaryInstNFLOP("tanh", -1, -1, expectedValue); + } + + @Test + public void testUcumkPlusInstNFLOPDefaultSparsity() { + long expectedValue = 1000 * 1000; + testUnaryInstNFLOP("ucumk+", -1, -1, expectedValue); + } + + @Test + public void testUcumkPlusInstNFLOPSparse() { + long expectedValue = (long) (0.5 * 1000 * 1000); + testUnaryInstNFLOP("ucumk+", 0.5, 0.5, expectedValue); + } + + @Test + public void testUcumMinInstNFLOPDefaultSparsity() { + long expectedValue = 1000 * 1000; + testUnaryInstNFLOP("ucummin", -1, -1, expectedValue); + } + + @Test + public void testUcumMinInstNFLOPSparse() { + long expectedValue = (long) (0.5 * 1000 * 1000); + testUnaryInstNFLOP("ucummin", 0.5, 0.5, expectedValue); + } + + @Test + public void testUcumMaxInstNFLOPDefaultSparsity() { + long expectedValue = 1000 * 1000; + testUnaryInstNFLOP("ucummax", -1, -1, expectedValue); + } + + @Test + public void testUcumMaxInstNFLOPSparse() { + long expectedValue = (long) (0.5 * 1000 * 1000); + testUnaryInstNFLOP("ucummax", 0.5, 0.5, expectedValue); + } + + @Test + public void testUcumMultInstNFLOPDefaultSparsity() { + long expectedValue = 1000 * 1000; + testUnaryInstNFLOP("ucum*", -1, -1, expectedValue); + } + + @Test + public void testUcumMultInstNFLOPSparse() { + long expectedValue = (long) (0.5 * 1000 * 1000); + testUnaryInstNFLOP("ucum*", 0.5, 0.5, expectedValue); + } + + @Test + public void testUcumkPlusMultInstNFLOPDefaultSparsity() { + long expectedValue = 2 * 1000 * 1000; + testUnaryInstNFLOP("ucumk+*", -1, -1, expectedValue); + } + + @Test + public void testUcumkPlusMultInstNFLOPSparse() { + long expectedValue = (long) (2 * 0.5 * 1000 * 1000); + testUnaryInstNFLOP("ucumk+*", 0.5, 0.5, expectedValue); + } + + @Test + public void testStopInstNFLOP() { + long expectedValue = 0; + testUnaryInstNFLOP("stop", -1, -1, expectedValue); + } + + @Test + public void testTypeofInstNFLOP() { + long expectedValue = 1000 * 1000; + testUnaryInstNFLOP("typeof", -1, -1, expectedValue); + } + + @Test + public void testInverseInstNFLOPDefaultSparsity() { + long expectedValue = (long) ((4.0 / 3.0) * (1000 * 1000) * (1000 * 1000) * (1000 * 1000)); + testUnaryInstNFLOP("inverse", -1, -1, expectedValue); + } + + @Test + public void testInverseInstNFLOPSparse() { + long expectedValue = (long) ((4.0 / 3.0) * (1000 * 1000) * (0.5 * 1000 * 1000) * (0.5 *1000 * 1000)); + testUnaryInstNFLOP("inverse", 0.5, 0.5, expectedValue); + } + + @Test + public void testCholeskyInstNFLOPDefaultSparsity() { + long expectedValue = (long) ((1.0 / 3.0) * (1000 * 1000) * (1000 * 1000) * (1000 * 1000)); + testUnaryInstNFLOP("cholesky", -1, -1, expectedValue); + } + + @Test + public void testCholeskyInstNFLOPSparse() { + long expectedValue = (long) ((1.0 / 3.0) * (1000 * 1000) * (0.5 * 1000 * 1000) * (0.5 *1000 * 1000)); + testUnaryInstNFLOP("cholesky", 0.5, 0.5, expectedValue); + } + + @Test + public void testLogInstNFLOP() { + long expectedValue = 32 * 1000 * 1000; + testBuiltinInstNFLOP("log", -1, expectedValue); + } + + @Test + public void testLogNzInstNFLOPDefaultSparsity() { + long expectedValue = 32 * 1000 * 1000; + testBuiltinInstNFLOP("log_nz", -1, expectedValue); + } + + @Test + public void testLogNzInstNFLOPSparse() { + long expectedValue = (long) (32 * 0.5 * 1000 
* 1000); + testBuiltinInstNFLOP("log_nz", 0.5, expectedValue); + } + + @Test + public void testNrowInstNFLOP() { + long expectedValue = 10L; + testAggregateUnaryInstNFLOP("nrow", expectedValue); + } + + @Test + public void testNcolInstNFLOP() { + long expectedValue = 10L; + testAggregateUnaryInstNFLOP("ncol", expectedValue); + } + + @Test + public void testLengthInstNFLOP() { + long expectedValue = 10L; + testAggregateUnaryInstNFLOP("length", expectedValue); + } + + @Test + public void testExistsInstNFLOP() { + long expectedValue = 10L; + testAggregateUnaryInstNFLOP("exists", expectedValue); + } + + @Test + public void testLineageInstNFLOP() { + long expectedValue = 10L; + testAggregateUnaryInstNFLOP("lineage", expectedValue); + } + + @Test + public void testUakInstNFLOP() { + long expectedValue = 4 * 1000 * 1000; + testAggregateUnaryInstNFLOP("uak+", expectedValue); + } + + @Test + public void testUarkInstNFLOP() { + long expectedValue = 4L * 2000 * 2000; + testAggregateUnaryRowInstNFLOP("uark+", -1, expectedValue); + testAggregateUnaryRowInstNFLOP("uark+", 0.5, expectedValue); + } + + @Test + public void testUackInstNFLOP() { + long expectedValue = 4L * 3000 * 3000; + testAggregateUnaryColInstNFLOP("uack+", -1, expectedValue); + testAggregateUnaryColInstNFLOP("uack+", 0.5, expectedValue); + } + + @Test + public void testUasqkInstNFLOP() { + long expectedValue = 5L * 1000 * 1000; + testAggregateUnaryInstNFLOP("uasqk+", expectedValue); + } + + @Test + public void testUarsqkInstNFLOP() { + long expectedValue = 5L * 2000 * 2000; + testAggregateUnaryRowInstNFLOP("uarsqk+", -1, expectedValue); + testAggregateUnaryRowInstNFLOP("uarsqk+", 0.5, expectedValue); + } + + @Test + public void testUacsqkInstNFLOP() { + long expectedValue = 5L * 3000 * 3000; + testAggregateUnaryColInstNFLOP("uacsqk+", -1, expectedValue); + testAggregateUnaryColInstNFLOP("uacsqk+", 0.5, expectedValue); + } + + @Test + public void testUameanInstNFLOP() { + long expectedValue = 7L * 1000 * 1000; + testAggregateUnaryInstNFLOP("uamean", expectedValue); + } + + @Test + public void testUarmeanInstNFLOP() { + long expectedValue = 7L * 2000 * 2000; + testAggregateUnaryRowInstNFLOP("uarmean", -1, expectedValue); + testAggregateUnaryRowInstNFLOP("uarmean", 0.5, expectedValue); + } + + @Test + public void testUacmeanInstNFLOP() { + long expectedValue = 7L * 3000 * 3000; + testAggregateUnaryColInstNFLOP("uacmean", -1, expectedValue); + testAggregateUnaryColInstNFLOP("uacmean", 0.5, expectedValue); + } + + @Test + public void testUavarInstNFLOP() { + long expectedValue = 14L * 1000 * 1000; + testAggregateUnaryInstNFLOP("uavar", expectedValue); + } + + @Test + public void testUarvarInstNFLOP() { + long expectedValue = 14L * 2000 * 2000; + testAggregateUnaryRowInstNFLOP("uarvar", -1, expectedValue); + testAggregateUnaryRowInstNFLOP("uarvar", 0.5, expectedValue); + } + + @Test + public void testUacvarInstNFLOP() { + long expectedValue = 14L * 3000 * 3000; + testAggregateUnaryColInstNFLOP("uacvar", -1, expectedValue); + testAggregateUnaryColInstNFLOP("uacvar", 0.5, expectedValue); + } + + @Test + public void testUamaxInstNFLOP() { + long expectedValue = 1000 * 1000; + testAggregateUnaryInstNFLOP("uamax", expectedValue); + } + + @Test + public void testUarmaxInstNFLOP() { + long expectedValue = 2000 * 2000; + testAggregateUnaryRowInstNFLOP("uarmax", -1, expectedValue); + testAggregateUnaryRowInstNFLOP("uarmax", 0.5, expectedValue); + } + + @Test + public void testUarimaxInstNFLOP() { + long expectedValue = 2000 * 2000; + 
testAggregateUnaryRowInstNFLOP("uarimax", -1, expectedValue); + testAggregateUnaryRowInstNFLOP("uarimax", 0.5, expectedValue); + } + + @Test + public void testUacmaxInstNFLOP() { + long expectedValue = 3000 * 3000; + testAggregateUnaryColInstNFLOP("uacmax", -1, expectedValue); + testAggregateUnaryColInstNFLOP("uacmax", 0.5, expectedValue); + } + + @Test + public void testUaminInstNFLOP() { + long expectedValue = 1000 * 1000; + testAggregateUnaryInstNFLOP("uamin", expectedValue); + } + + @Test + public void testUarminInstNFLOP() { + long expectedValue = 2000 * 2000; + testAggregateUnaryRowInstNFLOP("uarmin", -1, expectedValue); + testAggregateUnaryRowInstNFLOP("uarmin", 0.5, expectedValue); + } + + @Test + public void testUariminInstNFLOP() { + long expectedValue = 2000 * 2000; + testAggregateUnaryRowInstNFLOP("uarimin", -1, expectedValue); + testAggregateUnaryRowInstNFLOP("uarimin", 0.5, expectedValue); + } + + @Test + public void testUacminInstNFLOP() { + long expectedValue = 3000 * 3000; + testAggregateUnaryColInstNFLOP("uacmin", -1, expectedValue); + testAggregateUnaryColInstNFLOP("uacmin", 0.5, expectedValue); + } + + // HELPERS + + private void testUnaryInstNFLOP(String opcode, double sparsityIn, double sparsityOut, long expectedNFLOP) { + long nnzIn = sparsityIn < 0? -1 : (long) (sparsityIn * 1000 * 1000); + VarStats input = generateVarStatsMatrix("_mVar1", 1000, 1000, nnzIn); + long nnzOut = sparsityOut < 0? -1 : (long) (sparsityOut * 1000 * 1000); + VarStats output = generateVarStatsMatrix("_mVar2", 1000, 1000, nnzOut); + + long result = CPCostUtils.getInstNFLOP(CPType.Unary, opcode, output, input); + assertEquals(expectedNFLOP, result); + } + + private void testBuiltinInstNFLOP(String opcode, double sparsityIn, long expectedNFLOP) { + long nnz = sparsityIn < 0? -1 : (long) (sparsityIn * 1000 * 1000); + VarStats input = generateVarStatsMatrix("_mVar1", 1000, 1000, nnz); + VarStats output = generateVarStatsMatrix("_mVar2", 1000, 1000, -1); + + long result = CPCostUtils.getInstNFLOP(CPType.Unary, opcode, output, input); + assertEquals(expectedNFLOP, result); + } + + private void testAggregateUnaryInstNFLOP(String opcode, long expectedNFLOP) { + VarStats input = generateVarStatsMatrix("_mVar1", 1000, 1000, -1); + VarStats output = generateVarStatsScalarLiteral("_Var2"); + + long result = CPCostUtils.getInstNFLOP(CPType.AggregateUnary, opcode, output, input); + assertEquals(expectedNFLOP, result); + } + + private void testAggregateUnaryRowInstNFLOP(String opcode, double sparsityOut, long expectedNFLOP) { + VarStats input = generateVarStatsMatrix("_mVar1", 2000, 1000, -1); + long nnzOut = sparsityOut < 0? -1 : (long) (sparsityOut * 2000); + VarStats output = generateVarStatsMatrix("_mVar2", 2000, 1, nnzOut); + + long result = CPCostUtils.getInstNFLOP(CPType.AggregateUnary, opcode, output, input); + assertEquals(expectedNFLOP, result); + } + + private void testAggregateUnaryColInstNFLOP(String opcode, double sparsityOut, long expectedNFLOP) { + VarStats input = generateVarStatsMatrix("_mVar1", 1000, 3000, -1); + long nnzOut = sparsityOut < 0? 
-1 : (long) (sparsityOut * 3000);
+		VarStats output = generateVarStatsMatrix("_mVar2", 1, 3000, nnzOut);
+
+		long result = CPCostUtils.getInstNFLOP(CPType.AggregateUnary, opcode, output, input);
+		assertEquals(expectedNFLOP, result);
+	}
+
+	private VarStats generateVarStatsMatrix(String name, long rows, long cols, long nnz) {
+		MatrixCharacteristics mc = new MatrixCharacteristics(rows, cols, nnz);
+		return new VarStats(name, mc);
+	}
+
+	private VarStats generateVarStatsScalarLiteral(String nameOrValue) {
+		return new VarStats(nameOrValue, null);
+	}
 }
diff --git a/src/test/java/org/apache/sysds/test/component/resource/InstructionsCostEstimatorTest.java b/src/test/java/org/apache/sysds/test/component/resource/InstructionsCostEstimatorTest.java
index fe2bae90bf0..452e82e1442 100644
--- a/src/test/java/org/apache/sysds/test/component/resource/InstructionsCostEstimatorTest.java
+++ b/src/test/java/org/apache/sysds/test/component/resource/InstructionsCostEstimatorTest.java
@@ -43,170 +43,169 @@ import static org.apache.sysds.test.component.resource.TestingUtils.getSimpleCloudInstanceMap;
 public class InstructionsCostEstimatorTest {
-    private static final HashMap<String, CloudInstance> instanceMap = getSimpleCloudInstanceMap();
-
-    private CostEstimator estimator;
-
-    @Before
-    public void setup() {
-        ResourceCompiler.setDriverConfigurations(GBtoBytes(8), 4);
-        ResourceCompiler.setExecutorConfigurations(4, GBtoBytes(8), 4);
-        estimator = new CostEstimator(new Program(), instanceMap.get("m5.xlarge"), instanceMap.get("m5.xlarge"));
-    }
-
-    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
-    // Tests for CP Instructions //
-    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
-
-    @Test
-    public void createvarMatrixVariableCPInstructionTest() throws CostEstimationException {
-        String instDefinition = "CP°createvar°testVar°testOutputFile°false°MATRIX°binary°100°100°1000°10000°COPY";
-        VariableCPInstruction inst = VariableCPInstruction.parseInstruction(instDefinition);
-        testGettingTimeEstimateForCPInst(estimator, null, inst, 0);
-        // test the proper maintainCPInstVariableStatistics functionality
-        estimator.maintainStats(inst);
-        VarStats actualStats = estimator.getStats("testVar");
-        Assert.assertNotNull(actualStats);
-        Assert.assertEquals(10000, actualStats.getCells());
-    }
-
-    @Test
-    public void createvarFrameVariableCPInstructionTest() throws CostEstimationException {
-        String instDefinition = "CP°createvar°testVar°testOutputFile°false°FRAME°binary°100°100°1000°10000°COPY";
-        VariableCPInstruction inst = VariableCPInstruction.parseInstruction(instDefinition);
-        testGettingTimeEstimateForCPInst(estimator, null, inst, 0);
-        // test the proper maintainCPInstVariableStatistics functionality
-        estimator.maintainStats(inst);
-        VarStats actualStats = estimator.getStats("testVar");
-        Assert.assertNotNull(actualStats);
-        Assert.assertEquals(10000, actualStats.getCells());
-    }
-
-    @Test
-    public void createvarInvalidVariableCPInstructionTest() throws CostEstimationException {
-        String instDefinition = "CP°createvar°testVar°testOutputFile°false°TENSOR°binary°100°100°1000°10000°copy";
-        VariableCPInstruction inst = VariableCPInstruction.parseInstruction(instDefinition);
-        try {
-            estimator.maintainStats(inst);
-            testGettingTimeEstimateForCPInst(estimator, null, inst, 0);
-            Assert.fail("Tensor is not supported by the cost estimator");
-        } catch (RuntimeException e) {
-            // needed catch block to assert that RuntimeException has been thrown
-        }
-    }
-
-    @Test
-    public void randCPInstructionTest() throws CostEstimationException {
-        HashMap<String, VarStats> inputStats = new HashMap<>();
-        inputStats.put("matrixVar", generateStats("matrixVar", 10000, 10000, -1));
-        inputStats.put("outputVar", generateStats("outputVar", 10000, 10000, -1));
-
-        String instDefinition = "CP°+°scalarVar·SCALAR·FP64·false°matrixVar·MATRIX·FP64°outputVar·MATRIX·FP64";
-        BinaryCPInstruction inst = BinaryCPInstruction.parseInstruction(instDefinition);
-        testGettingTimeEstimateForCPInst(estimator, inputStats, inst, -1);
-    }
-
-    @Test
-    public void randCPInstructionExceedMemoryBudgetTest() {
-        HashMap<String, VarStats> inputStats = new HashMap<>();
-        inputStats.put("matrixVar", generateStats("matrixVar", 1000000, 1000000, -1));
-        inputStats.put("outputVar", generateStats("outputVar", 1000000, 1000000, -1));
-
-        String instDefinition = "CP°+°scalarVar·SCALAR·FP64·false°matrixVar·MATRIX·FP64°outputVar·MATRIX·FP64";
-        BinaryCPInstruction inst = BinaryCPInstruction.parseInstruction(instDefinition);
-        try {
-            testGettingTimeEstimateForCPInst(estimator, inputStats, inst, -1);
-            Assert.fail("CostEstimationException should have been thrown for the given data size and instruction");
-        } catch (CostEstimationException e) {
-            // needed catch block to assert that CostEstimationException has been thrown
-        }
-    }
-
-    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
-    // Tests for Spark Instructions //
-    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
-
-    @Test
-    public void plusBinaryMatrixMatrixSpInstructionTest() throws CostEstimationException {
-        HashMap<String, VarStats> inputStats = new HashMap<>();
-        inputStats.put("matrixVar", generateStatsWithRdd("matrixVar", 1000000,1000000, 500000000000L));
-        inputStats.put("outputVar", generateStats("outputVar", 1000000,1000000, -1));
-
-        String instDefinition = "SPARK°+°scalarVar·SCALAR·FP64·false°matrixVar·MATRIX·FP64°outputVar·MATRIX·FP64";
-        BinarySPInstruction inst = BinarySPInstruction.parseInstruction(instDefinition);
-        testGettingTimeEstimateForSparkInst(estimator, inputStats, inst, "outputVar", -1);
-    }
-
-    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
-    // Helper methods for testing Instructions //
-    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
-
-    private VarStats generateStats(String name, long m, long n, long nnz) {
-        MatrixCharacteristics mc = new MatrixCharacteristics(m, n, nnz);
-        VarStats ret = new VarStats(name, mc);
-        long size = OptimizerUtils.estimateSizeExactSparsity(ret.getM(), ret.getN(), ret.getSparsity());
-        ret.setAllocatedMemory(size);
-        return ret;
-    }
-
-    private VarStats generateStatsWithRdd(String name, long m, long n, long nnz) {
-        MatrixCharacteristics mc = new MatrixCharacteristics(m, n, nnz);
-        VarStats stats = new VarStats(name, mc);
-        RDDStats rddStats = new RDDStats(stats);
-        stats.setRddStats(rddStats);
-        return stats;
-    }
-
-    private static void testGettingTimeEstimateForCPInst(
-            CostEstimator estimator,
-            HashMap<String, VarStats> inputStats,
-            CPInstruction targetInstruction,
-            double expectedCost
-    ) throws CostEstimationException {
-        if (inputStats != null)
-            estimator.putStats(inputStats);
-        double actualCost = estimator.getTimeEstimateInst(targetInstruction);
-
-        if (expectedCost < 0) {
-            // check error-free cost estimation and meaningful result
-            Assert.assertTrue(actualCost > 0);
-        } else {
-            // check error-free cost estimation and exact result
-            Assert.assertEquals(expectedCost, actualCost, 0.0);
-        }
-    }
-
-    private static void testGettingTimeEstimateForSparkInst(
-            CostEstimator estimator,
-            HashMap<String, VarStats> inputStats,
-            SPInstruction targetInstruction,
-            String outputVar,
-            double expectedCost
-    ) throws CostEstimationException {
-        if (inputStats != null)
-            estimator.putStats(inputStats);
-        double actualCost = estimator.getTimeEstimateInst(targetInstruction);
-        RDDStats outputRDD = estimator.getStats(outputVar).getRddStats();
-        if (outputRDD.isCollected()) {
-            // cost directly returned
-            if (expectedCost < 0) {
-                // check error-free cost estimation and meaningful result
-                Assert.assertTrue(actualCost > 0);
-            } else {
-                // check error-free cost estimation and exact result
-                Assert.assertEquals(expectedCost, actualCost, 0.0);
-            }
-        } else {
-            // cost saved in RDD statistics
-            double sparkCost = outputRDD.getCost();
-            if (expectedCost < 0) {
-                // check error-free cost estimation and meaningful result
-                Assert.assertTrue(sparkCost > 0);
-            } else {
-                // check error-free cost estimation and exact result
-                Assert.assertEquals(expectedCost, sparkCost, 0.0);
-            }
-        }
-    }
-
+	private static final HashMap<String, CloudInstance> instanceMap = getSimpleCloudInstanceMap();
+
+	private CostEstimator estimator;
+
+	@Before
+	public void setup() {
+		ResourceCompiler.setDriverConfigurations(GBtoBytes(8), 4);
+		ResourceCompiler.setExecutorConfigurations(4, GBtoBytes(8), 4);
+		estimator = new CostEstimator(new Program(), instanceMap.get("m5.xlarge"), instanceMap.get("m5.xlarge"));
+	}
+
+	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+	// Tests for CP Instructions //
+	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+	@Test
+	public void createvarMatrixVariableCPInstructionTest() throws CostEstimationException {
+		String instDefinition = "CP°createvar°testVar°testOutputFile°false°MATRIX°binary°100°100°1000°10000°COPY";
+		VariableCPInstruction inst = VariableCPInstruction.parseInstruction(instDefinition);
+		testGettingTimeEstimateForCPInst(estimator, null, inst, 0);
+		// test the proper maintainCPInstVariableStatistics functionality
+		estimator.maintainStats(inst);
+		VarStats actualStats = estimator.getStats("testVar");
+		Assert.assertNotNull(actualStats);
+		Assert.assertEquals(10000, actualStats.getCells());
+	}
+
+	@Test
+	public void createvarFrameVariableCPInstructionTest() throws CostEstimationException {
+		String instDefinition = "CP°createvar°testVar°testOutputFile°false°FRAME°binary°100°100°1000°10000°COPY";
+		VariableCPInstruction inst = VariableCPInstruction.parseInstruction(instDefinition);
+		testGettingTimeEstimateForCPInst(estimator, null, inst, 0);
+		// test the proper maintainCPInstVariableStatistics functionality
+		estimator.maintainStats(inst);
+		VarStats actualStats = estimator.getStats("testVar");
+		Assert.assertNotNull(actualStats);
+		Assert.assertEquals(10000, actualStats.getCells());
+	}
+
+	@Test
+	public void createvarInvalidVariableCPInstructionTest() throws CostEstimationException {
+		String instDefinition = "CP°createvar°testVar°testOutputFile°false°TENSOR°binary°100°100°1000°10000°copy";
+		VariableCPInstruction inst = VariableCPInstruction.parseInstruction(instDefinition);
+		try {
+			estimator.maintainStats(inst);
+			testGettingTimeEstimateForCPInst(estimator, null, inst, 0);
+			Assert.fail("Tensor is not supported by the cost estimator");
+		} catch (RuntimeException e) {
+			// needed catch block to assert that RuntimeException has been thrown
+		}
+	}
+
+	@Test
+	public void randCPInstructionTest() throws CostEstimationException {
+		HashMap<String, VarStats> inputStats = new HashMap<>();
+		inputStats.put("matrixVar", generateStats("matrixVar", 10000, 10000, -1));
+		inputStats.put("outputVar", generateStats("outputVar", 10000, 10000, -1));
+
+		String instDefinition = "CP°+°scalarVar·SCALAR·FP64·false°matrixVar·MATRIX·FP64°outputVar·MATRIX·FP64";
+		BinaryCPInstruction inst = BinaryCPInstruction.parseInstruction(instDefinition);
+		testGettingTimeEstimateForCPInst(estimator, inputStats, inst, -1);
+	}
+
+	@Test
+	public void randCPInstructionExceedMemoryBudgetTest() {
+		HashMap<String, VarStats> inputStats = new HashMap<>();
+		inputStats.put("matrixVar", generateStats("matrixVar", 1000000, 1000000, -1));
+		inputStats.put("outputVar", generateStats("outputVar", 1000000, 1000000, -1));
+
+		String instDefinition = "CP°+°scalarVar·SCALAR·FP64·false°matrixVar·MATRIX·FP64°outputVar·MATRIX·FP64";
+		BinaryCPInstruction inst = BinaryCPInstruction.parseInstruction(instDefinition);
+		try {
+			testGettingTimeEstimateForCPInst(estimator, inputStats, inst, -1);
+			Assert.fail("CostEstimationException should have been thrown for the given data size and instruction");
+		} catch (CostEstimationException e) {
+			// needed catch block to assert that CostEstimationException has been thrown
+		}
+	}
+
+	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+	// Tests for Spark Instructions //
+	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+	@Test
+	public void plusBinaryMatrixMatrixSpInstructionTest() throws CostEstimationException {
+		HashMap<String, VarStats> inputStats = new HashMap<>();
+		inputStats.put("matrixVar", generateStatsWithRdd("matrixVar", 1000000,1000000, 500000000000L));
+		inputStats.put("outputVar", generateStats("outputVar", 1000000,1000000, -1));
+
+		String instDefinition = "SPARK°+°scalarVar·SCALAR·FP64·false°matrixVar·MATRIX·FP64°outputVar·MATRIX·FP64";
+		BinarySPInstruction inst = BinarySPInstruction.parseInstruction(instDefinition);
+		testGettingTimeEstimateForSparkInst(estimator, inputStats, inst, "outputVar", -1);
+	}
+
+	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+	// Helper methods for testing Instructions //
+	////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+	private VarStats generateStats(String name, long m, long n, long nnz) {
+		MatrixCharacteristics mc = new MatrixCharacteristics(m, n, nnz);
+		VarStats ret = new VarStats(name, mc);
+		long size = OptimizerUtils.estimateSizeExactSparsity(ret.getM(), ret.getN(), ret.getSparsity());
+		ret.setAllocatedMemory(size);
+		return ret;
+	}
+
+	private VarStats generateStatsWithRdd(String name, long m, long n, long nnz) {
+		MatrixCharacteristics mc = new MatrixCharacteristics(m, n, nnz);
+		VarStats stats = new VarStats(name, mc);
+		RDDStats rddStats = new RDDStats(stats);
+		stats.setRddStats(rddStats);
+		return stats;
+	}
+
+	private static void testGettingTimeEstimateForCPInst(
+			CostEstimator estimator,
+			HashMap<String, VarStats> inputStats,
+			CPInstruction targetInstruction,
+			double expectedCost
+	) throws CostEstimationException {
+		if (inputStats != null)
+			estimator.putStats(inputStats);
+		double actualCost = 
estimator.getTimeEstimateInst(targetInstruction); + + if (expectedCost < 0) { + // check error-free cost estimation and meaningful result + Assert.assertTrue(actualCost > 0); + } else { + // check error-free cost estimation and exact result + Assert.assertEquals(expectedCost, actualCost, 0.0); + } + } + + private static void testGettingTimeEstimateForSparkInst( + CostEstimator estimator, + HashMap inputStats, + SPInstruction targetInstruction, + String outputVar, + double expectedCost + ) throws CostEstimationException { + if (inputStats != null) + estimator.putStats(inputStats); + double actualCost = estimator.getTimeEstimateInst(targetInstruction); + RDDStats outputRDD = estimator.getStats(outputVar).getRddStats(); + if (outputRDD.isCollected()) { + // cost directly returned + if (expectedCost < 0) { + // check error-free cost estimation and meaningful result + Assert.assertTrue(actualCost > 0); + } else { + // check error-free cost estimation and exact result + Assert.assertEquals(expectedCost, actualCost, 0.0); + } + } else { + // cost saved in RDD statistics + double sparkCost = outputRDD.getCost(); + if (expectedCost < 0) { + // check error-free cost estimation and meaningful result + Assert.assertTrue(sparkCost > 0); + } else { + // check error-free cost estimation and exact result + Assert.assertEquals(expectedCost, sparkCost, 0.0); + } + } + } } diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinImageSamplePairingLinearizedTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinImageSamplePairingLinearizedTest.java index ef9f38606bb..35f70bd1232 100644 --- a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinImageSamplePairingLinearizedTest.java +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinImageSamplePairingLinearizedTest.java @@ -36,70 +36,70 @@ @net.jcip.annotations.NotThreadSafe public class BuiltinImageSamplePairingLinearizedTest extends AutomatedTestBase { - private final static String TEST_NAME_LINEARIZED = "image_sample_pairing_linearized"; - private final static String TEST_DIR = "functions/builtin/"; - private final static String TEST_CLASS_DIR = TEST_DIR + BuiltinImageSamplePairingLinearizedTest.class.getSimpleName() + "/"; - private final static double eps = 1e-10; - private final static double spSparse = 0.05; - private final static double spDense = 0.5; - - @Parameterized.Parameter() - public int value; - - @Parameterized.Parameters - public static Collection data() { - return Arrays.asList(new Object[][] { - {10}, - {-5}, - {3} - }); - } - - @Override - public void setUp() { - - addTestConfiguration(TEST_NAME_LINEARIZED, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_LINEARIZED, new String[]{"B_x"})); - } - - @Test - public void testImageSamplePairingLinearized() { - runImageSamplePairingLinearizedTest(false, ExecType.CP); - } - - private void runImageSamplePairingLinearizedTest(boolean sparse, ExecType instType) { - ExecMode platformOld = setExecMode(instType); - disableOutAndExpectedDeletion(); - - try { - loadTestConfiguration(getTestConfiguration(TEST_NAME_LINEARIZED)); - - double sparsity = sparse ? 
spSparse : spDense; - String HOME = SCRIPT_DIR + TEST_DIR; - - fullDMLScriptName = HOME + TEST_NAME_LINEARIZED + ".dml"; - programArgs = new String[]{"-nvargs", - "in_file=" + input("A"), - "in_file_second=" + input("secondMatrix"), - "x_out_reshape_file=" + output("B_x_reshape"), - "x_out_file=" + output("B_x"), - "value=" +value - }; - - double[][] A = getRandomMatrix(100,50, 0, 255, sparsity, 7); - double[][] secondMatrix = getRandomMatrix(1,50, 0, 255, sparsity, 7); - writeInputMatrixWithMTD("A", A, true); - writeInputMatrixWithMTD("secondMatrix", secondMatrix, true); - runTest(true, false, null, -1); - - HashMap dmlfileLinearizedX = readDMLMatrixFromOutputDir("B_x"); - - HashMap dmlfileX = readDMLMatrixFromOutputDir("B_x_reshape"); - - TestUtils.compareMatrices(dmlfileLinearizedX, dmlfileX, eps, "Stat-DML-LinearizedX", "Stat-DML-X"); - - - } finally { - rtplatform = platformOld; - } - } + private final static String TEST_NAME_LINEARIZED = "image_sample_pairing_linearized"; + private final static String TEST_DIR = "functions/builtin/"; + private final static String TEST_CLASS_DIR = TEST_DIR + BuiltinImageSamplePairingLinearizedTest.class.getSimpleName() + "/"; + private final static double eps = 1e-10; + private final static double spSparse = 0.05; + private final static double spDense = 0.5; + + @Parameterized.Parameter() + public int value; + + @Parameterized.Parameters + public static Collection data() { + return Arrays.asList(new Object[][] { + {10}, + {-5}, + {3} + }); + } + + @Override + public void setUp() { + + addTestConfiguration(TEST_NAME_LINEARIZED, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_LINEARIZED, new String[]{"B_x"})); + } + + @Test + public void testImageSamplePairingLinearized() { + runImageSamplePairingLinearizedTest(false, ExecType.CP); + } + + private void runImageSamplePairingLinearizedTest(boolean sparse, ExecType instType) { + ExecMode platformOld = setExecMode(instType); + disableOutAndExpectedDeletion(); + + try { + loadTestConfiguration(getTestConfiguration(TEST_NAME_LINEARIZED)); + + double sparsity = sparse ? 
spSparse : spDense; + String HOME = SCRIPT_DIR + TEST_DIR; + + fullDMLScriptName = HOME + TEST_NAME_LINEARIZED + ".dml"; + programArgs = new String[]{"-nvargs", + "in_file=" + input("A"), + "in_file_second=" + input("secondMatrix"), + "x_out_reshape_file=" + output("B_x_reshape"), + "x_out_file=" + output("B_x"), + "value=" +value + }; + + double[][] A = getRandomMatrix(100,50, 0, 255, sparsity, 7); + double[][] secondMatrix = getRandomMatrix(1,50, 0, 255, sparsity, 7); + writeInputMatrixWithMTD("A", A, true); + writeInputMatrixWithMTD("secondMatrix", secondMatrix, true); + runTest(true, false, null, -1); + + HashMap dmlfileLinearizedX = readDMLMatrixFromOutputDir("B_x"); + + HashMap dmlfileX = readDMLMatrixFromOutputDir("B_x_reshape"); + + TestUtils.compareMatrices(dmlfileLinearizedX, dmlfileX, eps, "Stat-DML-LinearizedX", "Stat-DML-X"); + + + } finally { + rtplatform = platformOld; + } + } } diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java index 44d06ffaf9a..ae7552cd16e 100644 --- a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java @@ -415,7 +415,6 @@ public void testTSNEEarlyStopping() throws IOException { runTSNEEarlyStoppingTest(2, 30, 300., 0.9, 1000, 1e-1, 1, "TRUE", 10, ExecType.CP); } - @SuppressWarnings("unused") private void runTSNEEarlyStoppingTest( Integer reduced_dims, Integer perplexity,