diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java
index 90683eab292..df9fd84f779 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java
@@ -62,6 +62,7 @@
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.transform.TfUtils;
 import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderBagOfWords;
 import org.apache.sysds.runtime.transform.encode.ColumnEncoderBin;
 import org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite;
 import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode;
@@ -263,6 +264,7 @@ public Iterator> call(Iterator>
 			// encoder-specific outputs
 			List<ColumnEncoderRecode> raEncoders = _encoder.getColumnEncoders(ColumnEncoderRecode.class);
 			List<ColumnEncoderBin> baEncoders = _encoder.getColumnEncoders(ColumnEncoderBin.class);
+			List<ColumnEncoderBagOfWords> bowEncoders = _encoder.getColumnEncoders(ColumnEncoderBagOfWords.class);
 			ArrayList<Tuple2<Integer, Object>> ret = new ArrayList<>();
 
 			// output recode maps as columnID - token pairs
@@ -273,8 +275,14 @@
 				for(Entry<Integer, HashSet<Object>> e1 : tmp.entrySet())
 					for(Object token : e1.getValue())
 						ret.add(new Tuple2<>(e1.getKey(), token));
-				if(!raEncoders.isEmpty())
-					raEncoders.forEach(columnEncoderRecode -> columnEncoderRecode.getCPRecodeMapsPartial().clear());
+				raEncoders.forEach(columnEncoderRecode -> columnEncoderRecode.getCPRecodeMapsPartial().clear());
+			}
+
+			if(!bowEncoders.isEmpty()) {
+				for(ColumnEncoderBagOfWords bowEnc : bowEncoders)
+					for(Object token : bowEnc.getPartialTokenDictionary())
+						ret.add(new Tuple2<>(bowEnc.getColID(), token));
+				bowEncoders.forEach(enc -> enc.getPartialTokenDictionary().clear());
 			}
 
 			// output binning column min/max as columnID - min/max pairs
@@ -321,7 +329,8 @@ public Iterator call(Tuple2> arg0) throws Exce
 			StringBuilder sb = new StringBuilder();
 
 			// handle recode maps
-			if(_encoder.containsEncoderForID(colID, ColumnEncoderRecode.class)) {
+			if(_encoder.containsEncoderForID(colID, ColumnEncoderRecode.class) ||
+				_encoder.containsEncoderForID(colID, ColumnEncoderBagOfWords.class)) {
 				while(iter.hasNext()) {
 					String token = TfUtils.sanitizeSpaces(iter.next().toString());
 					sb.append(rowID).append(' ').append(scolID).append(' ');
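For context, the Spark build path above only ships (columnID, token) pairs; the global dictionary is assembled afterwards on the driver side. The following self-contained sketch (not SystemDS code; all names are illustrative) shows that fold of per-partition partial token sets into one dictionary with dense, 1-based IDs, mirroring what BowMergePartialBuildTask does further below:

    import java.util.*;

    public class TokenDictMergeSketch {
        // fold partial token sets into a global map with dense, 1-based IDs
        public static Map<Object, Long> merge(List<Set<String>> partials) {
            Map<Object, Long> dict = new HashMap<>();
            for(Set<String> part : partials)
                for(String token : part)
                    if(!dict.containsKey(token)) // first writer wins
                        dict.put(token, (long) dict.size() + 1);
            return dict;
        }
        public static void main(String[] args) {
            List<Set<String>> parts = Arrays.asList(
                new HashSet<>(Arrays.asList("the", "cat")),
                new HashSet<>(Arrays.asList("cat", "sat")));
            System.out.println(merge(parts)); // e.g. {the=1, cat=2, sat=3} (order may vary)
        }
    }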
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
index 61e6e799f04..4d4b0124445 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
@@ -88,6 +88,7 @@
 import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
 import org.apache.sysds.runtime.transform.decode.Decoder;
 import org.apache.sysds.runtime.transform.decode.DecoderFactory;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderBagOfWords;
 import org.apache.sysds.runtime.transform.encode.EncoderFactory;
 import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
 import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
@@ -1056,6 +1057,13 @@ public Tuple2 call(Tuple2 in) throws Excepti
 			// execute block transform apply
 			MultiColumnEncoder encoder = _bencoder.getValue();
+			// We need to create a copy of the encoder because the bag-of-words encoder stores frame-block-specific
+			// state, which would be overwritten when multiple blocks reside on the same executor.
+			// To avoid this, we create a shallow copy of the MultiColumnEncoder in which only the bag-of-words
+			// encoders are newly instantiated (with fresh block-specific fields); all other fields (like meta)
+			// and all other encoders are shallow-copied and reused.
+			if(!encoder.getColumnEncoders(ColumnEncoderBagOfWords.class).isEmpty())
+				encoder = new MultiColumnEncoder(encoder); // create copy
 			MatrixBlock tmp = encoder.apply(blk);
 			// remap keys
 			if(_omap != null) {
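A minimal sketch of the defensive-copy pattern used above (illustrative names, not SystemDS classes): read-only state that is shared after the build phase is aliased, while per-block scratch state is dropped so each copy rebuilds it for its own frame block.

    import java.util.Map;

    class PerBlockCopySketch {
        final Map<Object, Long> dictionary; // shared after build, never mutated in apply
        int[] nnzPerRow;                    // block-specific scratch state

        PerBlockCopySketch(Map<Object, Long> dict) {
            this.dictionary = dict;
        }

        // copy constructor: share the dictionary, drop the block-local state
        PerBlockCopySketch(PerBlockCopySketch other) {
            this.dictionary = other.dictionary; // shallow copy
            this.nnzPerRow = null;              // recomputed per block
        }
    }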
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
index f10da3d9468..019df7f8470 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
@@ -450,7 +450,7 @@ protected void setApplyRowBlocksPerColumn(int nPart) {
 	}
 
 	public enum EncoderType {
-		Recode, FeatureHash, PassThrough, Bin, Dummycode, Omit, MVImpute, Composite, WordEmbedding,
+		Recode, FeatureHash, PassThrough, Bin, Dummycode, Omit, MVImpute, Composite, WordEmbedding, BagOfWords
 	}
 
 	/*
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBagOfWords.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBagOfWords.java
index c138901ad1c..85673d86dba 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBagOfWords.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBagOfWords.java
@@ -29,6 +29,9 @@
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.utils.stats.TransformStatistics;
 
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Comparator;
@@ -46,7 +49,8 @@ public class ColumnEncoderBagOfWords extends ColumnEncoder {
 	public static int NUM_SAMPLES_MAP_ESTIMATION = 16000;
 
 	protected int[] nnzPerRow;
-	private Map<String, Integer> tokenDictionary;
+	private Map<Object, Long> tokenDictionary; // switched from int to long to reuse code from RecodeEncoder
+	private HashSet<Object> tokenDictionaryPart = null;
 	protected String seperatorRegex = "\\s+"; // whitespace
 	protected boolean caseSensitive = false;
 	protected long nnz = 0;
@@ -62,11 +66,43 @@ public ColumnEncoderBagOfWords() {
 		super(-1);
 	}
 
+	public ColumnEncoderBagOfWords(ColumnEncoderBagOfWords enc) {
+		super(enc._colID);
+		this.nnzPerRow = enc.nnzPerRow != null ? enc.nnzPerRow.clone() : null;
+		this.tokenDictionary = enc.tokenDictionary;
+		this.seperatorRegex = enc.seperatorRegex;
+		this.caseSensitive = enc.caseSensitive;
+	}
+
+	public void setTokenDictionary(HashMap<Object, Long> dict) {
+		this.tokenDictionary = dict;
+	}
+
+	public Map<Object, Long> getTokenDictionary() {
+		return tokenDictionary;
+	}
+
 	protected void initNnzPartials(int rows, int numBlocks) {
 		this.nnzPerRow = new int[rows];
 		this.nnzPartials = new long[numBlocks];
 	}
 
+	public double computeNnzEstimate(CacheBlock<?> in, int[] sampleIndices) {
+		// estimates the nnz per row for this encoder
+		final int max_index = Math.min(ColumnEncoderBagOfWords.NUM_SAMPLES_MAP_ESTIMATION, sampleIndices.length);
+		int nnz = 0;
+		for(int i = 0; i < max_index; i++) {
+			int sind = sampleIndices[i];
+			String current = in.getString(sind, this._colID - 1);
+			if(current != null)
+				for(String token : tokenize(current, caseSensitive, seperatorRegex))
+					if(!token.isEmpty() && tokenDictionary.containsKey(token)) {
+						nnz++;
+					}
+		}
+		return (double) nnz / max_index;
+	}
+
 	public void computeMapSizeEstimate(CacheBlock<?> in, int[] sampleIndices) {
 		// Find the frequencies of distinct values in the sample after tokenization
 		HashMap<String, Integer> distinctFreq = new HashMap<>();
@@ -118,6 +154,18 @@ public void computeMapSizeEstimate(CacheBlock in, int[] sampleIndices) {
 		_estMetaSize = _estNumDistincts * _avgEntrySize;
 	}
 
+	public void computeNnzPerRow(CacheBlock<?> in, int start, int end) {
+		for(int i = start; i < end; i++) {
+			String current = in.getString(i, this._colID - 1);
+			HashSet<String> distinctTokens = new HashSet<>();
+			if(current != null)
+				for(String token : tokenize(current, caseSensitive, seperatorRegex))
+					if(!token.isEmpty() && tokenDictionary.containsKey(token))
+						distinctTokens.add(token);
+			this.nnzPerRow[i] = distinctTokens.size();
+		}
+	}
+
 	public static String[] tokenize(String current, boolean caseSensitive, String seperatorRegex) {
 		// string builder is faster than regex
 		StringBuilder finalString = new StringBuilder();
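A standalone approximation (not the SystemDS implementation) of the whitespace tokenization plus the per-row count of distinct dictionary tokens computed above; the real tokenize() avoids regex via a StringBuilder, while this sketch uses String.split for brevity:

    import java.util.*;

    public class BowNnzSketch {
        static int nnzForRow(String row, Set<String> dict, boolean caseSensitive) {
            Set<String> distinct = new HashSet<>();
            for(String token : row.split("\\s+")) {
                if(!caseSensitive)
                    token = token.toLowerCase();
                if(!token.isEmpty() && dict.contains(token))
                    distinct.add(token); // count each dictionary token once per row
            }
            return distinct.size();
        }
        public static void main(String[] args) {
            Set<String> dict = new HashSet<>(Arrays.asList("cat", "sat"));
            System.out.println(nnzForRow("the cat sat on the cat", dict, false)); // 2
        }
    }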
@@ -150,8 +198,6 @@ protected TransformType getTransformType() {
 		return TransformType.BAG_OF_WORDS;
 	}
 
-
-
 	public Callable<Object> getBuildTask(CacheBlock<?> in) {
 		return new ColumnBagOfWordsBuildTask(this, in);
 	}
@@ -160,7 +206,7 @@ public Callable getBuildTask(CacheBlock in) {
 	public void build(CacheBlock<?> in) {
 		long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
 		tokenDictionary = new HashMap<>(_estNumDistincts);
-		int i = 0;
+		int i = 1;
 		this.nnz = 0;
 		nnzPerRow = new int[in.getNumRows()];
 		HashSet<String> tokenSetPerRow;
@@ -173,7 +219,7 @@ public void build(CacheBlock in) {
 					if(!token.isEmpty()) {
 						tokenSetPerRow.add(token);
 						if(!this.tokenDictionary.containsKey(token))
-							this.tokenDictionary.put(token, i++);
+							this.tokenDictionary.put(token, (long) i++);
 					}
 				this.nnzPerRow[r] = tokenSetPerRow.size();
 				this.nnz += tokenSetPerRow.size();
@@ -205,6 +251,33 @@ static class Pair {
 		}
 	}
 
+	@Override
+	public void prepareBuildPartial() {
+		// ensure allocated partial recode map
+		if(tokenDictionaryPart == null)
+			tokenDictionaryPart = new HashSet<>();
+	}
+
+	public HashSet<Object> getPartialTokenDictionary() {
+		return this.tokenDictionaryPart;
+	}
+
+	@Override
+	public void buildPartial(FrameBlock in) {
+		if(!isApplicable())
+			return;
+		for(int r = 0; r < in.getNumRows(); r++) {
+			String current = in.getString(r, this._colID - 1);
+			if(current != null)
+				for(String token : tokenize(current, caseSensitive, seperatorRegex)) {
+					if(!token.isEmpty()) {
+						tokenDictionaryPart.add(token);
+					}
+				}
+		}
+	}
+
 	protected void applySparse(CacheBlock<?> in, MatrixBlock out, int outputCol, int rowStart, int blk) {
 		boolean mcsr = MatrixBlock.DEFAULT_SPARSEBLOCK == SparseBlock.Type.MCSR;
 		mcsr = false; // force CSR for transformencode FIXME
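A short worked example of the 1-based ID convention introduced above (values are made up): token IDs now start at 1, matching the "token·id" recode-map entries written to the metadata frame, so the output column of a token is outputCol + id - 1, as in applySparse/applyDense below.

    public class BowColumnIndexSketch {
        public static void main(String[] args) {
            long id = 3;       // dictionary ID of some token, 1-based (as built above)
            int outputCol = 4; // first output column of this bag-of-words block (0-based)
            int col = (int) (outputCol + id - 1);
            System.out.println(col); // 6
        }
    }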
@@ -214,7 +287,7 @@ protected void applySparse(CacheBlock in, MatrixBlock out, int outputCol, int
 				throw new NotImplementedException();
 			}
 			else { // csr
-				HashMap<String, Integer> counter = countTokenAppearances(in, r, _colID-1, caseSensitive, seperatorRegex);
+				HashMap<String, Integer> counter = countTokenAppearances(in, r);
 				if(counter.isEmpty())
 					sparseRowsWZeros.add(r);
 				else {
@@ -222,12 +295,13 @@ protected void applySparse(CacheBlock in, MatrixBlock out, int outputCol, int
 					int[] rptr = csrblock.rowPointers();
 					// assert that nnz from build is equal to nnz from apply
 					assert counter.size() == nnzPerRow[r];
-					Pair[] columnValuePairs = new Pair[counter.size()];
+					Pair[] columnValuePairs = new Pair[nnzPerRow[r]];
 					int i = 0;
 					for(Map.Entry<String, Integer> entry : counter.entrySet()) {
 						String token = entry.getKey();
-						columnValuePairs[i] = new Pair(outputCol + tokenDictionary.get(token), entry.getValue());
-						i++;
+						columnValuePairs[i] = new Pair((int) (outputCol + tokenDictionary.getOrDefault(token, 0L) - 1), entry.getValue());
+						// if the token is not in the dictionary, columnValuePairs[i] is overwritten in the next iteration
+						i += tokenDictionary.containsKey(token) ? 1 : 0;
 					}
 					// insertion sort performs better on small arrays
 					if(columnValuePairs.length >= 128)
@@ -264,20 +338,20 @@ private static void insertionSort(Pair [] arr) {
 
 	@Override
 	protected void applyDense(CacheBlock<?> in, MatrixBlock out, int outputCol, int rowStart, int blk) {
 		for(int r = rowStart; r < Math.max(in.getNumRows(), rowStart + blk); r++) {
-			HashMap<String, Integer> counter = countTokenAppearances(in, r, _colID-1, caseSensitive, seperatorRegex);
+			HashMap<String, Integer> counter = countTokenAppearances(in, r);
 			for(Map.Entry<String, Integer> entry : counter.entrySet())
-				out.set(r, outputCol + tokenDictionary.get(entry.getKey()), entry.getValue());
+				out.set(r, (int) (outputCol + tokenDictionary.get(entry.getKey()) - 1), entry.getValue());
 		}
 	}
 
-	private static HashMap<String, Integer> countTokenAppearances(
-		CacheBlock<?> in, int r, int c, boolean caseSensitive, String separator)
+	private HashMap<String, Integer> countTokenAppearances(
+		CacheBlock<?> in, int r)
 	{
-		String current = in.getString(r, c);
+		String current = in.getString(r, _colID - 1);
 		HashMap<String, Integer> counter = new HashMap<>();
 		if(current != null)
-			for(String token : tokenize(current, caseSensitive, separator))
-				if(!token.isEmpty())
+			for(String token : tokenize(current, caseSensitive, seperatorRegex))
+				if(!token.isEmpty() && tokenDictionary.containsKey(token))
 					counter.put(token, counter.getOrDefault(token, 0) + 1);
 		return counter;
 	}
@@ -291,15 +365,41 @@ public void allocateMetaData(FrameBlock meta) {
 	public FrameBlock getMetaData(FrameBlock out) {
 		int rowID = 0;
 		StringBuilder sb = new StringBuilder();
-		for(Map.Entry<String, Integer> e : this.tokenDictionary.entrySet()) {
-			out.set(rowID++, _colID - 1, constructRecodeMapEntry(e.getKey(), Long.valueOf(e.getValue()), sb));
+		for(Map.Entry<Object, Long> e : this.tokenDictionary.entrySet()) {
+			out.set(rowID++, _colID - 1, constructRecodeMapEntry(e.getKey(), e.getValue(), sb));
 		}
 		return out;
 	}
 
 	@Override
 	public void initMetaData(FrameBlock meta) {
-		throw new NotImplementedException();
+		if(meta != null && meta.getNumRows() > 0) {
+			this.tokenDictionary = meta.getRecodeMap(_colID - 1);
+		}
+	}
+
+	@Override
+	public void writeExternal(ObjectOutput out) throws IOException {
+		super.writeExternal(out);
+
+		out.writeInt(tokenDictionary == null ? 0 : tokenDictionary.size());
+		if(tokenDictionary != null)
+			for(Map.Entry<Object, Long> e : tokenDictionary.entrySet()) {
+				out.writeUTF((String) e.getKey());
+				out.writeLong(e.getValue());
+			}
+	}
+
+	@Override
+	public void readExternal(ObjectInput in) throws IOException {
+		super.readExternal(in);
+		int size = in.readInt();
+		tokenDictionary = new HashMap<>(size * 4 / 3);
+		for(int j = 0; j < size; j++) {
+			String key = in.readUTF();
+			Long value = in.readLong();
+			tokenDictionary.put(key, value);
+		}
 	}
 
 	private static class BowPartialBuildTask implements Callable<Object> {
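A hedged sketch of the write/read symmetry used by writeExternal/readExternal above (standalone helper, not SystemDS code): an int size first, then (UTF key, long id) pairs; the reader pre-sizes the map with size * 4 / 3 so it never rehashes at the default 0.75 load factor.

    import java.io.*;
    import java.util.*;

    public class DictSerDeSketch {
        static void write(ObjectOutput out, Map<Object, Long> dict) throws IOException {
            out.writeInt(dict == null ? 0 : dict.size());
            if(dict != null)
                for(Map.Entry<Object, Long> e : dict.entrySet()) {
                    out.writeUTF((String) e.getKey());
                    out.writeLong(e.getValue());
                }
        }
        static Map<Object, Long> read(ObjectInput in) throws IOException {
            int size = in.readInt();
            Map<Object, Long> dict = new HashMap<>(size * 4 / 3); // avoid rehashing
            for(int i = 0; i < size; i++)
                dict.put(in.readUTF(), in.readLong());
            return dict;
        }
    }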
@@ -378,11 +478,11 @@ private BowMergePartialBuildTask(ColumnEncoderBagOfWords encoderRecode, HashMap<
 		@Override
 		public Object call() {
 			long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
-			Map<String, Integer> tokenDictionary = _encoder.tokenDictionary;
+			Map<Object, Long> tokenDictionary = _encoder.tokenDictionary;
 			for(Object tokenSet : _partialMaps.values()) {
 				((HashSet) tokenSet).forEach(token -> {
 					if(!tokenDictionary.containsKey(token))
-						tokenDictionary.put((String) token, tokenDictionary.size());
+						tokenDictionary.put(token, (long) tokenDictionary.size() + 1);
 				});
 			}
 			for(long nnzPartial : _encoder.nnzPartials)
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
index 95443729146..536b387a1da 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
@@ -62,9 +62,9 @@ public ColumnEncoderComposite() {
 
 	public ColumnEncoderComposite(List<ColumnEncoder> columnEncoders, FrameBlock meta) {
 		super(-1);
-		if(!(columnEncoders.size() > 0 &&
+		if(!(!columnEncoders.isEmpty() &&
 			columnEncoders.stream().allMatch((encoder -> encoder._colID == columnEncoders.get(0)._colID))))
-			throw new DMLRuntimeException("Tried to create Composite Encoder with no encoders or mismatching columIDs");
+			throw new DMLRuntimeException("Tried to create Composite Encoder with no encoders or mismatching columnIDs");
 		_colID = columnEncoders.get(0)._colID;
 		_meta = meta;
 		_columnEncoders = columnEncoders;
@@ -73,6 +73,11 @@ public ColumnEncoderComposite(List columnEncoders, FrameBlock met
 	public ColumnEncoderComposite(List<ColumnEncoder> columnEncoders) {
 		this(columnEncoders, null);
 	}
+
+	public ColumnEncoderComposite(List<ColumnEncoder> columnEncoders, int colID) {
+		super(colID);
+		_columnEncoders = columnEncoders;
+		_meta = null;
+	}
 
 	public ColumnEncoderComposite(ColumnEncoder columnEncoder) {
 		super(columnEncoder._colID);
@@ -166,7 +171,8 @@ public List> getBuildTasks(CacheBlock in) {
 			if(t == null)
 				continue;
 			// Linear execution between encoders so they can't be built in parallel
-			if(tasks.size() != 0) {
+			if(!tasks.isEmpty()) {
+				// TODO: is this still needed? Currently there is no CompositeEncoder with two encoders that have a build phase
 				// avoid unnecessary map initialization
 				depMap = (depMap == null) ? new HashMap<>() : depMap;
 				// This workaround is needed since sublist is only valid for effective final lists,
@@ -207,6 +213,8 @@ public void buildPartial(FrameBlock in) {
 	public MatrixBlock apply(CacheBlock<?> in, MatrixBlock out, int outputCol, int rowStart, int blk) {
 		try {
 			for(int i = 0; i < _columnEncoders.size(); i++) {
+				// set sparseRowPointerOffset in the encoder
+				_columnEncoders.get(i).sparseRowPointerOffset = this.sparseRowPointerOffset;
 				if(i == 0) {
 					// 1. encoder writes data into the MatrixBlock column; all others use this column for further encoding
 					_columnEncoders.get(i).apply(in, out, outputCol, rowStart, blk);
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
index 05ad8e46940..1c2478d711b 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
@@ -265,6 +265,8 @@ else if(columnEncoder instanceof ColumnEncoderRecode)
 			return EncoderType.Recode.ordinal();
 		else if(columnEncoder instanceof ColumnEncoderWordEmbedding)
 			return EncoderType.WordEmbedding.ordinal();
+		else if(columnEncoder instanceof ColumnEncoderBagOfWords)
+			return EncoderType.BagOfWords.ordinal();
 		throw new DMLRuntimeException("Unsupported encoder type: " + columnEncoder.getClass().getCanonicalName());
 	}
 
@@ -283,6 +285,8 @@ public static ColumnEncoder createInstance(int type) {
 			return new ColumnEncoderRecode();
 		case WordEmbedding:
 			return new ColumnEncoderWordEmbedding();
+		case BagOfWords:
+			return new ColumnEncoderBagOfWords();
 		default:
 			throw new DMLRuntimeException("Unsupported encoder type: " + etype);
 	}
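The factory round-trip above relies on stable enum ordinals: an encoder is shipped as EncoderType.ordinal() and reconstructed via createInstance(int). A minimal usage sketch, assuming the EncoderType order shown earlier in this patch (BagOfWords is the last constant, ordinal 9):

    import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
    import org.apache.sysds.runtime.transform.encode.ColumnEncoderBagOfWords;
    import org.apache.sysds.runtime.transform.encode.EncoderFactory;

    public class EncoderTypeRoundTripSketch {
        public static void main(String[] args) {
            int type = ColumnEncoder.EncoderType.BagOfWords.ordinal(); // 9
            ColumnEncoder enc = EncoderFactory.createInstance(type);
            System.out.println(enc instanceof ColumnEncoderBagOfWords); // true
        }
    }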
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
index 0417e67ba1f..5e249826615 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
@@ -89,6 +89,20 @@ public MultiColumnEncoder(List columnEncoders) {
 		_columnEncoders = columnEncoders;
 	}
 
+	public MultiColumnEncoder(MultiColumnEncoder menc) {
+		// This constructor creates a shallow copy for all encoders except for bag-of-words encoders
+		List<ColumnEncoderComposite> colEncs = menc._columnEncoders;
+		_columnEncoders = new ArrayList<>();
+		for(ColumnEncoderComposite cColEnc : colEncs) {
+			List<ColumnEncoder> newEncs = new ArrayList<>();
+			ColumnEncoderComposite cColEncCopy = new ColumnEncoderComposite(newEncs, cColEnc._colID);
+			_columnEncoders.add(cColEncCopy);
+			for(ColumnEncoder enc : cColEnc.getEncoders()) {
+				newEncs.add(enc instanceof ColumnEncoderBagOfWords ?
+					new ColumnEncoderBagOfWords((ColumnEncoderBagOfWords) enc) : enc);
+			}
+		}
+	}
+
 	public MultiColumnEncoder() {
 		_columnEncoders = new ArrayList<>();
 	}
@@ -327,16 +341,14 @@ public MatrixBlock apply(CacheBlock in) {
 
 	public MatrixBlock apply(CacheBlock<?> in, int k) {
 		// domain sizes are not updated if called from transformapply
-		boolean hasUDF = _columnEncoders.stream().anyMatch(e -> e.hasEncoder(ColumnEncoderUDF.class));
-		boolean hasWE = _columnEncoders.stream().anyMatch(e -> e.hasEncoder(ColumnEncoderWordEmbedding.class));
-		for(ColumnEncoderComposite columnEncoder : _columnEncoders)
-			columnEncoder.updateAllDCEncoders();
+		EncoderMeta encm = getEncMeta(_columnEncoders, true, k, in);
+		updateAllDCEncoders();
 		int numCols = getNumOutCols();
-		long estNNz = (long) in.getNumRows() * (hasUDF ? numCols : hasWE ? getEstNNzRow() : in.getNumColumns());
-		// FIXME: estimate nnz for multiple encoders including dummycode and embedding
-		boolean sparse = MatrixBlock.evalSparseFormatInMemory(in.getNumRows(), numCols, estNNz) && !hasUDF;
+		long estNNz = (long) in.getNumRows() * (encm.hasWE || encm.hasUDF ? numCols : (in.getNumColumns() - encm.numBOWEnc) + encm.nnzBOW);
+		// FIXME: estimate nnz for multiple encoders including dummycode
+		boolean sparse = MatrixBlock.evalSparseFormatInMemory(in.getNumRows(), numCols, estNNz) && !encm.hasUDF;
 		MatrixBlock out = new MatrixBlock(in.getNumRows(), numCols, sparse, estNNz);
-		return apply(in, out, 0, k);
+		return apply(in, out, 0, k, encm, estNNz);
 	}
 
 	public void updateAllDCEncoders() {
@@ -345,10 +357,11 @@ public void updateAllDCEncoders() {
 	}
 
 	public MatrixBlock apply(CacheBlock<?> in, MatrixBlock out, int outputCol) {
-		return apply(in, out, outputCol, 1);
+		// unused method, only exists currently because of the interface
+		throw new DMLRuntimeException("MultiColumnEncoder apply without Encoder Characteristics should not be called directly");
 	}
 
-	public MatrixBlock apply(CacheBlock<?> in, MatrixBlock out, int outputCol, int k) {
+	public MatrixBlock apply(CacheBlock<?> in, MatrixBlock out, int outputCol, int k, EncoderMeta encm, long nnz) {
 		// There should be an encoder for every column
 		if(hasLegacyEncoder() && !(in instanceof FrameBlock))
 			throw new DMLRuntimeException("LegacyEncoders do not support non FrameBlock Inputs");
@@ -361,31 +374,20 @@ public MatrixBlock apply(CacheBlock in, MatrixBlock out, int outputCol, int k
 		if(in.getNumRows() == 0)
 			throw new DMLRuntimeException("Invalid input with wrong number or rows");
-		boolean hasDC = false;
-		boolean hasWE = false;
-		//TODO adapt transform apply for BOW
-		int distinctWE = 0;
-		int sizeWE = 0;
-		for(ColumnEncoderComposite columnEncoder : _columnEncoders) {
-			hasDC |= columnEncoder.hasEncoder(ColumnEncoderDummycode.class);
-			for (ColumnEncoder enc : columnEncoder.getEncoders())
-				if(enc instanceof ColumnEncoderWordEmbedding){
-					hasWE = true;
-					distinctWE = ((ColumnEncoderWordEmbedding) enc).getNrDistinctEmbeddings();
-					sizeWE = enc.getDomainSize();
-				}
-		}
-		outputMatrixPreProcessing(out, in, hasDC, hasWE, distinctWE, sizeWE, 0, null, -1);
+		ArrayList<int[]> nnzOffsets = outputMatrixPreProcessing(out, in, encm, nnz, k);
 		if(k > 1) {
 			if(!_partitionDone) // happens if this method is directly called
 				deriveNumRowPartitions(in, k);
-			applyMT(in, out, outputCol, k);
+			applyMT(in, out, outputCol, k, nnzOffsets);
 		}
 		else {
-			int offset = outputCol;
+			int offset = outputCol, i = 0;
+			int[] nnzOffset = null;
 			for(ColumnEncoderComposite columnEncoder : _columnEncoders) {
+				columnEncoder.sparseRowPointerOffset = nnzOffset;
 				columnEncoder.apply(in, out, columnEncoder._colID - 1 + offset);
-				offset = getOffset(offset, columnEncoder);
+				offset = getOutputColOffset(offset, columnEncoder);
+				nnzOffset = nnzOffsets != null ? nnzOffsets.get(i++) : null;
 			}
 		}
 		// Recomputing NNZ since we access the block directly
@@ -399,36 +401,44 @@ public MatrixBlock apply(CacheBlock in, MatrixBlock out, int outputCol, int k
 		return out;
 	}
 
-	private List<DependencyTask<?>> getApplyTasks(CacheBlock<?> in, MatrixBlock out, int outputCol) {
+	private List<DependencyTask<?>> getApplyTasks(CacheBlock<?> in, MatrixBlock out, int outputCol, ArrayList<int[]> nnzOffsets) {
 		List<DependencyTask<?>> tasks = new ArrayList<>();
 		int offset = outputCol;
+		int i = 0;
+		int[] currentNnzOffsets = null;
 		for(ColumnEncoderComposite e : _columnEncoders) {
-			tasks.addAll(e.getApplyTasks(in, out, e._colID - 1 + offset, null));
-			offset = getOffset(offset, e);
+			tasks.addAll(e.getApplyTasks(in, out, e._colID - 1 + offset, currentNnzOffsets));
+			currentNnzOffsets = nnzOffsets != null ? nnzOffsets.get(i++) : null;
+			offset = getOutputColOffset(offset, e);
 		}
 		return tasks;
 	}
 
-	private int getOffset(int offset, ColumnEncoderComposite e) {
+	private int getOutputColOffset(int offset, ColumnEncoderComposite e) {
 		if(e.hasEncoder(ColumnEncoderDummycode.class))
 			offset += e.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1;
 		if(e.hasEncoder(ColumnEncoderWordEmbedding.class))
 			offset += e.getEncoder(ColumnEncoderWordEmbedding.class).getDomainSize() - 1;
+		if(e.hasEncoder(ColumnEncoderBagOfWords.class))
+			offset += e.getEncoder(ColumnEncoderBagOfWords.class).getDomainSize() - 1;
 		return offset;
 	}
 
-	private void applyMT(CacheBlock<?> in, MatrixBlock out, int outputCol, int k) {
+	private void applyMT(CacheBlock<?> in, MatrixBlock out, int outputCol, int k, ArrayList<int[]> nnzOffsets) {
 		DependencyThreadPool pool = new DependencyThreadPool(k);
 		try {
 			if(APPLY_ENCODER_SEPARATE_STAGES) {
 				int offset = outputCol;
+				int i = 0;
+				int[] currentNnzOffsets = null;
 				for (ColumnEncoderComposite e : _columnEncoders) {
-					// for now bag of words is only used in encode
-					pool.submitAllAndWait(e.getApplyTasks(in, out, e._colID - 1 + offset, null));
-					offset = getOffset(offset, e);
+					pool.submitAllAndWait(e.getApplyTasks(in, out, e._colID - 1 + offset, currentNnzOffsets));
+					offset = getOutputColOffset(offset, e);
+					currentNnzOffsets = nnzOffsets != null ? nnzOffsets.get(i) : null;
+					i++;
 				}
 			}
 			else
-				pool.submitAllAndWait(getApplyTasks(in, out, outputCol));
+				pool.submitAllAndWait(getApplyTasks(in, out, outputCol, nnzOffsets));
 		}
 		catch(ExecutionException | InterruptedException e) {
 			throw new DMLRuntimeException(e);
@@ -635,7 +645,7 @@ private void estimateMapSize(CacheBlock in, List encL
 		}
 	}
 
-	private int[] getSampleIndices(CacheBlock<?> in, int sampleSize, int seed, int k) {
+	private static int[] getSampleIndices(CacheBlock<?> in, int sampleSize, int seed, int k) {
 		return ComEstSample.getSortedSample(in.getNumRows(), sampleSize, seed, k);
 	}
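A worked example of getOutputColOffset above (all numbers assumed): each expanding encoder widens the output, so the output columns of later input columns shift right by the accumulated surplus.

    public class OutputColOffsetSketch {
        public static void main(String[] args) {
            int offset = 0;              // accumulated widening across columns
            int bowDomainSize = 5;       // assumed dictionary size of a BOW column (input col 1)
            offset += bowDomainSize - 1; // the BOW column occupies 5 output columns instead of 1
            int recodeColId = 2;         // 1-based input column of the next encoder
            int outCol = recodeColId - 1 + offset;
            System.out.println(outCol);  // 5: recode output shifted right by 4
        }
    }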
@@ -659,11 +669,11 @@ private long getTotalMemOverhead(CacheBlock in, int nBuildpart, List
 		}
 	}
 
-	private static void outputMatrixPreProcessing(MatrixBlock output, CacheBlock<?> input, boolean hasDC, boolean hasWE,
-		int distinctWE, int sizeWE, int numBOW, int[] nnzPerRowBOW, int nnz) {
+	private static ArrayList<int[]> outputMatrixPreProcessing(MatrixBlock output, CacheBlock<?> input, EncoderMeta encm, long nnz, int k) {
 		long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
 		if(nnz < 0)
-			nnz = output.getNumRows() * input.getNumColumns();
+			nnz = (long) output.getNumRows() * input.getNumColumns();
+		ArrayList<int[]> bowNnzRowOffsets = null;
 		if(output.isInSparseFormat()) {
 			if (MatrixBlock.DEFAULT_SPARSEBLOCK != SparseBlock.Type.CSR
 				&& MatrixBlock.DEFAULT_SPARSEBLOCK != SparseBlock.Type.MCSR)
@@ -673,7 +683,7 @@ private static void outputMatrixPreProcessing(MatrixBlock output, CacheBlock
 			if (mcsr) {
 				output.allocateBlock();
 				SparseBlock block = output.getSparseBlock();
-				if (hasDC && OptimizerUtils.getTransformNumThreads()>1) {
+				if (encm.hasDC && OptimizerUtils.getTransformNumThreads()>1) {
 					// DC forces a single threaded allocation after the build phase and
 					// before the apply starts. Below code parallelizes sparse allocation.
 					IntStream.range(0, output.getNumRows())
@@ -695,30 +705,52 @@
 				}
 			}
 			else { //csr
-				SparseBlockCSR csrblock = new SparseBlockCSR(output.getNumRows(), nnz, nnz);
 				// Manually fill the row pointers based on nnzs/row (= #cols in the input)
 				// Not using the set() methods to 1) avoid binary search and shifting,
 				// 2) reduce thread contentions on the arrays
-				int[] rptr = csrblock.rowPointers();
-				if(nnzPerRowBOW != null)
-					for (int i=0; i
 		LOG.debug("Elapsed time for allocation: "+ ((double) System.nanoTime() - t0) / 1000000 + " ms");
 		TransformStatistics.incOutMatrixPreProcessingTime(System.nanoTime()-t0);
+		return bowNnzRowOffsets;
 	}
+
+	private static ArrayList<int[]> getNnzPerRowFromBOWEncoders(CacheBlock<?> input, EncoderMeta encm, int k) {
+		ArrayList<int[]> bowNnzRowOffsets;
+		int min_block_size = 1000;
+		int num_blocks = input.getNumRows() / min_block_size;
+		// 1 <= num_blks1 <= k / #enc
+		int num_blks1 = Math.min((k + encm.numBOWEnc - 1) / encm.numBOWEnc, Math.max(num_blocks, 1));
+		int blk_len1 = (input.getNumRows() + num_blks1 - 1) / num_blks1;
+		// 1 <= num_blks2 <= k
+		int num_blks2 = Math.min(k, Math.max(num_blocks, 1));
+		int blk_len2 = (input.getNumRows() + num_blks2 - 1) / num_blks2;
+
+		ExecutorService pool = CommonThreadPool.get(k);
+		ArrayList<int[]> bowNnzRowOffsetsFinal = new ArrayList<>();
+		try {
+			encm.bowEncoders.forEach(e -> e.nnzPerRow = new int[input.getNumRows()]);
+			ArrayList<Future<?>> list = new ArrayList<>();
+			for (int i = 0; i < num_blks1; i++) {
+				int start = i * blk_len1;
+				int end = Math.min((i + 1) * blk_len1, input.getNumRows());
+				list.add(pool.submit(() -> encm.bowEncoders.stream().parallel().forEach(e -> e.computeNnzPerRow(input, start, end))));
+			}
+			for(Future<?> f : list)
+				f.get();
+			list.clear();
+			int[] previous = null;
+			for(ColumnEncoderComposite enc : encm.encs) {
+				if(enc.hasEncoder(ColumnEncoderBagOfWords.class)) {
+					previous = previous == null ? enc.getEncoder(ColumnEncoderBagOfWords.class).nnzPerRow
+						: new int[input.getNumRows()];
+				}
+				bowNnzRowOffsetsFinal.add(previous);
+			}
+			for (int i = 0; i < num_blks2; i++) {
+				int start = i * blk_len2;
+				list.add(pool.submit(() -> aggregateNnzPerRow(start, blk_len2, input.getNumRows(),
+					encm.encs, bowNnzRowOffsetsFinal)));
+			}
+			for(Future<?> f : list)
+				f.get();
+			bowNnzRowOffsets = bowNnzRowOffsetsFinal;
+		}
+		catch(Exception ex) {
+			throw new DMLRuntimeException(ex);
+		}
+		finally {
+			pool.shutdown();
+		}
+		return bowNnzRowOffsets;
+	}
+
+	private static void aggregateNnzPerRow(int start, int blk_len, int numRows, List<ColumnEncoderComposite> encs, ArrayList<int[]> bowNnzRowOffsets) {
+		int end = Math.min(start + blk_len, numRows);
+		int pos = 0;
+		int[] aggRowOffsets = null;
+		for(ColumnEncoderComposite enc : encs) {
+			int[] currentOffsets = bowNnzRowOffsets.get(pos);
+			if (enc.hasEncoder(ColumnEncoderBagOfWords.class)) {
+				ColumnEncoderBagOfWords bow = enc.getEncoder(ColumnEncoderBagOfWords.class);
+				if(aggRowOffsets == null) {
+					aggRowOffsets = currentOffsets;
+				}
+				else {
+					for (int i = start; i < end; i++) {
+						currentOffsets[i] = aggRowOffsets[i] + bow.nnzPerRow[i] - 1;
+					}
+				}
+			}
+			pos++;
+		}
+	}
 
 	private void outputMatrixPostProcessing(MatrixBlock output, int k){
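A worked example of the ceil-division block partitioning above, with assumed numbers (10,000 input rows, k = 8 threads, 3 bag-of-words encoders, minimum block size 1000):

    public class BlockPartitionSketch {
        public static void main(String[] args) {
            int rows = 10000, k = 8, numBowEnc = 3, minBlockSize = 1000;
            int numBlocks = rows / minBlockSize;                                              // 10
            int numBlks1 = Math.min((k + numBowEnc - 1) / numBowEnc, Math.max(numBlocks, 1)); // 3
            int blkLen1 = (rows + numBlks1 - 1) / numBlks1;                                   // 3334
            int numBlks2 = Math.min(k, Math.max(numBlocks, 1));                               // 8
            int blkLen2 = (rows + numBlks2 - 1) / numBlks2;                                   // 1250
            System.out.println(numBlks1 + "x" + blkLen1 + ", " + numBlks2 + "x" + blkLen2);
        }
    }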
@@ -822,7 +925,7 @@ public FrameBlock getMetaData(FrameBlock meta, int k) {
 				tasks.add(new ColumnMetaDataTask<>(columnEncoder, meta));
 			List<Future<Object>> taskret = pool.invokeAll(tasks);
 			for (Future<Object> task : taskret)
-				task.get();
+				task.get();
 		}
 		catch(Exception ex) {
 			throw new DMLRuntimeException(ex);
@@ -1021,13 +1124,6 @@ public List> getEncoderTypes() {
 		return getEncoderTypes(-1);
 	}
 
-	public int getEstNNzRow(){
-		int nnz = 0;
-		for(int i = 0; i < _columnEncoders.size(); i++)
-			nnz += _columnEncoders.get(i).getDomainSize();
-		return nnz;
-	}
-
 	public int getNumOutCols() {
 		int sum = 0;
 		for(int i = 0; i < _columnEncoders.size(); i++)
@@ -1218,6 +1314,92 @@ public String toString() {
 		return sb.toString();
 	}
 
+	private static class EncoderMeta {
+		// contains information about the encoders and their relevant data characteristics
+		public final boolean hasUDF;
+		public final boolean hasDC;
+		public final boolean hasWE;
+		public final int distinctWE;
+		public final int sizeWE;
+		public final long nnzBOW;
+		public final int numBOWEnc;
+		public final int[] nnzPerRowBOW;
+		public final ArrayList<ColumnEncoderBagOfWords> bowEncoders;
+		public final List<ColumnEncoderComposite> encs;
+
+		public EncoderMeta(boolean hasUDF, boolean hasDC, boolean hasWE, int distinctWE, int sizeWE, long nnzBOW,
+			int numBOWEncoder, int[] nnzPerRowBOW, ArrayList<ColumnEncoderBagOfWords> bows,
+			List<ColumnEncoderComposite> encoders) {
+			this.hasUDF = hasUDF;
+			this.hasDC = hasDC;
+			this.hasWE = hasWE;
+			this.distinctWE = distinctWE;
+			this.sizeWE = sizeWE;
+			this.nnzBOW = nnzBOW;
+			this.numBOWEnc = numBOWEncoder;
+			this.nnzPerRowBOW = nnzPerRowBOW;
+			this.bowEncoders = bows;
+			this.encs = encoders;
+		}
+	}
+
+	private static EncoderMeta getEncMeta(List<ColumnEncoderComposite> encoders, boolean noBuild, int k, CacheBlock<?> in) {
+		boolean hasUDF = false, hasDC = false, hasWE = false;
+		int distinctWE = 0;
+		int sizeWE = 0;
+		long nnzBOW = 0;
+		int numBOWEncoder = 0;
+		int[] nnzPerRowBOW = null;
+		ArrayList<ColumnEncoderBagOfWords> bows = new ArrayList<>();
+		for (ColumnEncoderComposite enc : encoders) {
+			if(enc.hasEncoder(ColumnEncoderUDF.class))
+				hasUDF = true;
+			else if (enc.hasEncoder(ColumnEncoderDummycode.class))
+				hasDC = true;
+			else if(enc.hasEncoder(ColumnEncoderBagOfWords.class)) {
+				ColumnEncoderBagOfWords bowEnc = enc.getEncoder(ColumnEncoderBagOfWords.class);
+				numBOWEncoder++;
+				nnzBOW += bowEnc.nnz;
+				if(noBuild) {
+					// estimate nnz by sampling
+					bows.add(bowEnc);
+				} else if(nnzPerRowBOW != null)
+					for (int i = 0; i < bowEnc.nnzPerRow.length; i++) {
+						nnzPerRowBOW[i] += bowEnc.nnzPerRow[i];
+					}
+				else {
+					nnzPerRowBOW = bowEnc.nnzPerRow.clone();
+				}
+			}
+			else if(enc.hasEncoder(ColumnEncoderWordEmbedding.class)) {
+				hasWE = true;
+				distinctWE = enc.getEncoder(ColumnEncoderWordEmbedding.class).getNrDistinctEmbeddings();
+				sizeWE = enc.getDomainSize();
+			}
+		}
+		if(!bows.isEmpty()) {
+			int[] sampleInds = getSampleIndices(in, in.getNumRows() > 1000 ? (int) (0.1 * in.getNumRows()) : in.getNumRows(), (int) System.nanoTime(), 1);
+			// Concurrent (column-wise) bag-of-words nnz estimation per row: we only estimate the number of
+			// non-zeros because the exact count is needed for sparse outputs only; if the output is sparse,
+			// we recount the nnz for all rows later.
+			// Note: the sampling might be problematic since we use it for the sparsity estimation, which
+			// impacts performance if we end up choosing the non-ideal output format.
+			ExecutorService pool = CommonThreadPool.get(k);
+			try {
+				Double result = pool.submit(() -> bows.stream().parallel()
+					.mapToDouble(e -> e.computeNnzEstimate(in, sampleInds))
+					.sum()).get();
+				nnzBOW = (long) Math.ceil(result);
+			}
+			catch(Exception ex) {
+				throw new DMLRuntimeException(ex);
+			}
+			finally {
+				pool.shutdown();
+			}
+		}
+		return new EncoderMeta(hasUDF, hasDC, hasWE, distinctWE, sizeWE, nnzBOW, numBOWEncoder, nnzPerRowBOW, bows, encoders);
+	}
+
 	/*
 	 * Currently not in use; will be integrated in the future
 	 */
@@ -1271,41 +1453,12 @@ private InitOutputMatrixTask(MultiColumnEncoder encoder, CacheBlock input, Ma
 		@Override
 		public Object call() {
-			boolean hasUDF = false, hasDC = false, hasWE = false;
-			int distinctWE = 0;
-			int sizeWE = 0;
-			long nnzBOW = 0;
-			int numBOWEncoder = 0;
-			int[] nnzPerRowBOW = null;
-			for (ColumnEncoderComposite enc : _encoder.getEncoders()){
-				if(enc.hasEncoder(ColumnEncoderUDF.class))
-					hasUDF = true;
-				else if (enc.hasEncoder(ColumnEncoderDummycode.class))
-					hasDC = true;
-				else if(enc.hasEncoder(ColumnEncoderBagOfWords.class)){
-					ColumnEncoderBagOfWords bowEnc = enc.getEncoder(ColumnEncoderBagOfWords.class);
-					numBOWEncoder++;
-					nnzBOW += bowEnc.nnz;
-					if(nnzPerRowBOW != null)
-						for (int i = 0; i < bowEnc.nnzPerRow.length; i++) {
-							nnzPerRowBOW[i] += bowEnc.nnzPerRow[i];
-						}
-					else {
-						nnzPerRowBOW = bowEnc.nnzPerRow.clone();
-					}
-				}
-				else if(enc.hasEncoder(ColumnEncoderWordEmbedding.class)){
-					hasWE = true;
-					distinctWE = enc.getEncoder(ColumnEncoderWordEmbedding.class).getNrDistinctEmbeddings();
-					sizeWE = enc.getDomainSize();
-				}
-			}
-
+			EncoderMeta encm = getEncMeta(_encoder.getEncoders(), false, -1, _input);
 			int numCols = _encoder.getNumOutCols();
-			long estNNz = (long) _input.getNumRows() * (hasUDF ? numCols : _input.getNumColumns() - numBOWEncoder) + nnzBOW;
-			boolean sparse = MatrixBlock.evalSparseFormatInMemory(_input.getNumRows(), numCols, estNNz) && !hasUDF;
+			long estNNz = (long) _input.getNumRows() * (encm.hasUDF ? numCols : _input.getNumColumns() - encm.numBOWEnc) + encm.nnzBOW;
+			boolean sparse = MatrixBlock.evalSparseFormatInMemory(_input.getNumRows(), numCols, estNNz) && !encm.hasUDF;
 			_output.reset(_input.getNumRows(), numCols, sparse, estNNz);
-			outputMatrixPreProcessing(_output, _input, hasDC, hasWE, distinctWE, sizeWE, numBOWEncoder, nnzPerRowBOW, (int) estNNz);
+			outputMatrixPreProcessing(_output, _input, encm, estNNz, 1);
 			return null;
 		}
 
@@ -1381,13 +1534,14 @@ public String toString() {
 
 		@Override
 		public Object call() throws Exception {
+			// updates the outputCol offset and sets the nnz offsets, which are created by bow encoders, in each encoder
 			int currentCol = -1;
 			int currentOffset = 0;
 			int[] sparseRowPointerOffsets = null;
 			for(DependencyTask<?> dtask : _applyTasksWrappers) {
 				((ApplyTasksWrapperTask) dtask).setOffset(currentOffset);
 				if(sparseRowPointerOffsets != null)
-					((ApplyTasksWrapperTask) dtask).setSparseRowPointerOffsets(sparseRowPointerOffsets.clone());
+					((ApplyTasksWrapperTask) dtask).setSparseRowPointerOffsets(sparseRowPointerOffsets);
 				int nonOffsetCol = ((ApplyTasksWrapperTask) dtask)._encoder._colID - 1;
 				if(nonOffsetCol > currentCol) {
 					currentCol = nonOffsetCol;
@@ -1398,11 +1552,14 @@ else if (enc.hasEncoder(ColumnEncoderBagOfWords.class)) {
 					ColumnEncoderBagOfWords bow = enc.getEncoder(ColumnEncoderBagOfWords.class);
 					currentOffset += bow.getDomainSize() - 1;
 					if(sparseRowPointerOffsets == null)
-						sparseRowPointerOffsets = bow.nnzPerRow.clone();
-					else
+						sparseRowPointerOffsets = bow.nnzPerRow;
+					else {
+						sparseRowPointerOffsets = sparseRowPointerOffsets.clone();
+						// TODO: experiment whether it makes sense to parallelize here (for frames with many rows)
 						for (int r = 0; r < sparseRowPointerOffsets.length; r++) {
 							sparseRowPointerOffsets[r] += bow.nnzPerRow[r] - 1;
 						}
+					}
 				}
 			}
 		}
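A worked example of the sparse row-pointer offsets propagated above (nnz values are made up): each BOW column contributes nnzPerRow[r] cells to row r instead of the single cell a non-expanding encoder would write, so encoders to its right shift their CSR row pointers by the accumulated surplus, following the same initial-copy-then-accumulate scheme as the code above.

    public class RowPointerOffsetSketch {
        public static void main(String[] args) {
            int[] bow1NnzPerRow = {2, 0, 3}; // assumed nnz of the first BOW column per row
            int[] bow2NnzPerRow = {1, 4, 1}; // assumed nnz of the second BOW column per row
            // first BOW column: its nnzPerRow becomes the initial offset array
            int[] offsets = bow1NnzPerRow.clone();
            // each further BOW column adds its surplus over a single cell
            for(int r = 0; r < offsets.length; r++)
                offsets[r] += bow2NnzPerRow[r] - 1;
            System.out.println(java.util.Arrays.toString(offsets)); // [2, 3, 3]
        }
    }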
diff --git a/src/test/java/org/apache/sysds/test/functions/transform/ColumnEncoderMixedFunctionalityTests.java b/src/test/java/org/apache/sysds/test/functions/transform/ColumnEncoderMixedFunctionalityTests.java
new file mode 100644
index 00000000000..2e5f852ca8f
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/transform/ColumnEncoderMixedFunctionalityTests.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.transform;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderBagOfWords;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite;
+import org.apache.sysds.runtime.transform.encode.EncoderFactory;
+import org.apache.sysds.runtime.transform.encode.EncoderOmit;
+import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+
+public class ColumnEncoderMixedFunctionalityTests extends AutomatedTestBase {
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+	}
+
+	@Test
+	public void testCompositeConstructor1() {
+		ColumnEncoderComposite cEnc1 = new ColumnEncoderComposite(null, 1);
+		ColumnEncoderComposite cEnc2 = new ColumnEncoderComposite(cEnc1);
+		assert cEnc1.getColID() == cEnc2.getColID();
+	}
+
+	@Test
+	public void testCompositeConstructor2() {
+		List<ColumnEncoder> encoderList = new ArrayList<>();
+		encoderList.add(new ColumnEncoderComposite(null, 1));
+		encoderList.add(new ColumnEncoderComposite(null, 2));
+		DMLRuntimeException e = assertThrows(DMLRuntimeException.class, () -> new ColumnEncoderComposite(encoderList, null));
+		assertTrue(e.getMessage().contains("Tried to create Composite Encoder with no encoders or mismatching columnIDs"));
+	}
+
+	@Test
+	public void testEncoderFactoryGetUnsupportedEncoderType() {
+		DMLRuntimeException e = assertThrows(DMLRuntimeException.class, () -> EncoderFactory.getEncoderType(new ColumnEncoderComposite()));
+		assertTrue(e.getMessage().contains("Unsupported encoder type: org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite"));
+	}
+
+	@Test
+	public void testEncoderFactoryCreateUnsupportedInstanceType() {
+		// type(7) = Composite, which we don't use for encoding the type
+		DMLRuntimeException e = assertThrows(DMLRuntimeException.class, () -> EncoderFactory.createInstance(7));
+		assertTrue(e.getMessage().contains("Unsupported encoder type: Composite"));
+	}
+
+	@Test
+	public void testMultiColumnEncoderApplyWithWrongInputCharacteristics1() {
+		// apply call without metadata about encoders
+		MultiColumnEncoder mEnc = new MultiColumnEncoder();
+		DMLRuntimeException e = assertThrows(DMLRuntimeException.class, () -> mEnc.apply(null, null, 0));
+		assertTrue(e.getMessage().contains("MultiColumnEncoder apply without Encoder Characteristics should not be called directly"));
+	}
+
+	@Test
+	public void testMultiColumnEncoderApplyWithWrongInputCharacteristics2() {
+		// apply with LegacyEncoders + non-FrameBlock inputs
+		MultiColumnEncoder mEnc = new MultiColumnEncoder();
+		mEnc.addReplaceLegacyEncoder(new EncoderOmit());
+		DMLRuntimeException e = assertThrows(DMLRuntimeException.class, () -> mEnc.apply(new MatrixBlock(), null, 0, 0, null, 0L));
+		assertTrue(e.getMessage().contains("LegacyEncoders do not support non FrameBlock Inputs"));
+	}
+
+	@Test
+	public void testMultiColumnEncoderApplyWithWrongInputCharacteristics3() {
+		// #CompositeEncoders != #cols
+		ArrayList<ColumnEncoder> encs = new ArrayList<>();
+		encs.add(new ColumnEncoderBagOfWords());
+		ArrayList<ColumnEncoderComposite> cEncs = new ArrayList<>();
+		cEncs.add(new ColumnEncoderComposite(encs));
+		MultiColumnEncoder mEnc = new MultiColumnEncoder(cEncs);
+		FrameBlock in = new FrameBlock(2, Types.ValueType.FP64);
+		DMLRuntimeException e = assertThrows(DMLRuntimeException.class, () -> mEnc.apply(in, null, 0, 0, null, 0L));
+		assertTrue(e.getMessage().contains("Not every column in has a CompositeEncoder. Please make sure every column has a encoder or slice the input accordingly"));
+	}
+
+	@Test
+	public void testMultiColumnEncoderApplyWithWrongInputCharacteristics4() {
+		// input has 0 rows
+		MultiColumnEncoder mEnc = new MultiColumnEncoder();
+		MatrixBlock in = new MatrixBlock();
+		DMLRuntimeException e = assertThrows(DMLRuntimeException.class, () -> mEnc.apply(in, null, 0, 0, null, 0L));
+		assertTrue(e.getMessage().contains("Invalid input with wrong number or rows"));
+	}
+
+}
\ No newline at end of file
diff --git a/src/test/java/org/apache/sysds/test/functions/transform/ColumnEncoderSerializationTest.java b/src/test/java/org/apache/sysds/test/functions/transform/ColumnEncoderSerializationTest.java
index aeec927e73b..2bd1e646978 100644
--- a/src/test/java/org/apache/sysds/test/functions/transform/ColumnEncoderSerializationTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/transform/ColumnEncoderSerializationTest.java
@@ -25,11 +25,14 @@
 import java.io.ObjectInput;
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 
 import org.apache.sysds.common.Types;
 import org.apache.sysds.runtime.frame.data.FrameBlock;
 import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderBagOfWords;
 import org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite;
 import org.apache.sysds.runtime.transform.encode.EncoderFactory;
 import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
@@ -56,7 +59,8 @@ public enum TransformType {
 		RECODE,
 		DUMMY,
 		IMPUTE,
-		OMIT
+		OMIT,
+		BOW
 	}
 
 	@Override
@@ -88,6 +92,12 @@ public void setUp() {
 	@Test
 	public void testComposite8() { runTransformSerTest(TransformType.OMIT, schemaStrings); }
 
+	@Test
+	public void testComposite9() { runTransformSerTest(TransformType.BOW, schemaStrings); }
+
+	@Test
+	public void testComposite10() { runTransformSerTest(TransformType.BOW, schemaMixed); }
+
@@ -117,11 +127,21 @@ else if(type == TransformType.IMPUTE)
 				"{ \"id\": 7, \"method\": \"global_mode\" }, { \"id\": 9, \"method\": \"global_mean\" } ]\n\n}";
 		else if (type == TransformType.OMIT)
 			spec = "{ \"ids\": true, \"omit\": [ 1,2,4,5,6,7,8,9 ], \"recode\": [ 2, 7 ] }";
+		else if (type == TransformType.BOW)
+			spec = "{ \"ids\": true, \"omit\": [ 1,4,5,6,8,9 ], \"bag_of_words\": [ 2, 7 ] }";
 
 		frame.setSchema(schema);
 		String[] cnames = frame.getColumnNames();
 
 		MultiColumnEncoder encoderIn = EncoderFactory.createEncoder(spec, cnames, frame.getNumColumns(), null);
+		if(type == TransformType.BOW) {
+			List<ColumnEncoderBagOfWords> encs = encoderIn.getColumnEncoders(ColumnEncoderBagOfWords.class);
+			HashMap<Object, Long> dict = new HashMap<>();
+			dict.put("val1", 1L);
+			dict.put("val2", 2L);
+			dict.put("val3", 300L);
+			encs.forEach(e -> e.setTokenDictionary(dict));
+		}
 		MultiColumnEncoder encoderOut;
 
 		// serialization and deserialization
@@ -141,7 +161,16 @@ else if (type == TransformType.OMIT)
 		for(Class<?> classtype : typesIn) {
 			Assert.assertArrayEquals(encoderIn.getFromAllIntArray(classtype, ColumnEncoder::getColID),
 				encoderOut.getFromAllIntArray(classtype, ColumnEncoder::getColID));
 		}
-
+		if(type == TransformType.BOW) {
+			List<ColumnEncoderBagOfWords> encsIn = encoderIn.getColumnEncoders(ColumnEncoderBagOfWords.class);
+			List<ColumnEncoderBagOfWords> encsOut = encoderOut.getColumnEncoders(ColumnEncoderBagOfWords.class);
+			for (int i = 0; i < encsIn.size(); i++) {
+				Map<Object, Long> eOutDict = encsOut.get(i).getTokenDictionary();
+				encsIn.get(i).getTokenDictionary().forEach((k, v) -> {
+					assert v.equals(eOutDict.get(k));
+				});
+			}
+		}
 	}
diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeBagOfWords.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeBagOfWords.java
index e3d1c07be26..631ff7747ef 100644
--- a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeBagOfWords.java
+++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeBagOfWords.java
@@ -26,6 +26,7 @@
 import org.apache.sysds.runtime.frame.data.FrameBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixValue;
 import org.apache.sysds.runtime.transform.encode.ColumnEncoderBagOfWords;
+import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
@@ -36,6 +37,7 @@
 import java.nio.file.Path;
 import java.nio.file.Paths;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -47,6 +49,7 @@ public class TransformFrameEncodeBagOfWords extends AutomatedTestBase
 {
 	private final static String TEST_NAME1 = "TransformFrameEncodeBagOfWords";
+	private final static String TEST_NAME2 = "TransformFrameEncodeApplyBagOfWords";
 	private final static String TEST_DIR = "functions/transform/";
 	private final static String TEST_CLASS_DIR = TEST_DIR + TransformFrameEncodeBagOfWords.class.getSimpleName() + "/";
 	// for benchmarking: Digital_Music_Text.csv
@@ -56,6 +59,7 @@ public void setUp() {
 		TestUtils.clearAssertionInformation();
 		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1));
+		addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2));
 	}
 
 	// These tests result in dense output
@@ -63,6 +67,17 @@ public void setUp() {
 	public void testTransformBagOfWords() {
 		runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE, false, false);
 	}
+
+	@Test
+	public void testTransformApplyBagOfWords() {
+		runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, false, false);
+	}
+
+	@Test
+	public void testTransformApplySeparateStagesBagOfWords() {
+		MultiColumnEncoder.APPLY_ENCODER_SEPARATE_STAGES = true;
+		runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, false, false);
+		MultiColumnEncoder.APPLY_ENCODER_SEPARATE_STAGES = false;
+	}
 
 	@Test
 	public void testTransformBagOfWordsError() {
@@ -74,32 +89,62 @@ public void testTransformBagOfWordsPlusRecode() {
 		runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE, true, false);
 	}
 
+	@Test
+	public void testTransformApplyBagOfWordsPlusRecode() {
+		runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, true, false);
+	}
+
 	@Test
 	public void testTransformBagOfWords2() {
 		runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE, false, true);
 	}
 
+	@Test
+	public void testTransformApplyBagOfWords2() {
+		runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, false, true);
+	}
+
 	@Test
 	public void testTransformBagOfWordsPlusRecode2() {
 		runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE, true, true);
 	}
 
+	@Test
+	public void testTransformApplyBagOfWordsPlusRecode2() {
+		runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, true, true);
+	}
+
 	// AmazonReviewDataset transformation results in a sparse output
 	@Test
 	public void testTransformBagOfWordsAmazonReviews() {
 		runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE, false, false, true);
 	}
 
+	@Test
+	public void testTransformApplyBagOfWordsAmazonReviews() {
+		runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, false, false, true);
+	}
+
 	@Test
 	public void testTransformBagOfWordsAmazonReviews2() {
 		runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE, false, true, true);
 	}
 
+	@Test
+	public void testTransformApplyBagOfWordsAmazonReviews2() {
+		runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, false, true, true);
+	}
+
 	@Test
 	public void testTransformBagOfWordsAmazonReviewsAndRandRecode() {
 		runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE, true, false, true);
 	}
 
+	@Test
+	public void testTransformApplyBagOfWordsAmazonReviewsAndRandRecode() {
+		runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, true, false, true);
+	}
+
 	@Test
 	public void testTransformBagOfWordsAmazonReviewsAndDummyCode() {
 		// TODO: compare result
@@ -118,16 +163,55 @@ public void testTransformBagOfWordsAmazonReviewsAndRandRecode2() {
 	}
 
 	@Test
-	public void testNotImplementedFunction(){
-		ColumnEncoderBagOfWords bow = new ColumnEncoderBagOfWords();
-		assertThrows(NotImplementedException.class, () -> bow.initMetaData(null));
+	public void testTransformApplyBagOfWordsAmazonReviewsAndRandRecode2() {
+		runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, true, true, true);
+	}
+
+	@Test
+	public void testTransformApplySeparateStagesBagOfWordsAmazonReviewsAndRandRecode2() {
+		MultiColumnEncoder.APPLY_ENCODER_SEPARATE_STAGES = true;
+		runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, true, true, true);
+		MultiColumnEncoder.APPLY_ENCODER_SEPARATE_STAGES = false;
 	}
 
-	//@Test
+	@Test
 	public void testTransformBagOfWordsSpark() {
 		runTransformTest(TEST_NAME1, ExecMode.SPARK, false, false);
 	}
 
+	@Test
+	public void testTransformBagOfWordsAmazonReviewsSpark() {
+		runTransformTest(TEST_NAME1, ExecMode.SPARK, false, false, true);
+	}
+
+	@Test
+	public void testTransformBagOfWordsAmazonReviews2Spark() {
+		runTransformTest(TEST_NAME1, ExecMode.SPARK, false, true, true);
+	}
+
+	@Test
+	public void testTransformBagOfWordsAmazonReviewsAndRandRecodeSpark() {
+		runTransformTest(TEST_NAME1, ExecMode.SPARK, true, false, true);
+	}
+
+	@Test
+	public void testTransformBagOfWordsAmazonReviewsAndRandRecode2Spark() {
+		runTransformTest(TEST_NAME1, ExecMode.SPARK, true, true, true);
+	}
+
+	@Test
+	public void testBuildPartialBagOfWordsNotApplicable() {
+		ColumnEncoderBagOfWords bow = new ColumnEncoderBagOfWords();
+		assert bow.getColID() == -1;
+		try {
+			bow.buildPartial(null); // should run without error
+		}
+		catch (Exception e) {
+			throw new AssertionError("Test failed: Expected no errors due to early abort (colId = -1). "
+				+ "Encountered exception:\n" + e + "\nMessage: " + Arrays.toString(e.getStackTrace()));
+		}
+	}
+
 	private void runTransformTest(String testname, ExecMode rt, boolean recode, boolean dup) {
 		runTransformTest(testname, rt, recode, dup, false);
 	}
@@ -154,31 +238,31 @@ private void runTransformTest(String testname, ExecMode rt, boolean recode, bool
 
 		if(!fromFile)
 			writeStringsToCsvFile(sentenceColumn, recodeColumn, baseDirectory + INPUT_DIR + "data", dup);
-		int mode = 0;
-		if(error)
-			mode = 1;
-		if(dc)
-			mode = 2;
-		if(pt)
-			mode = 3;
-		programArgs = new String[]{"-stats","-args", fromFile ? DATASET_DIR + DATASET : input("data"),
+		int mode = error ? 1 : (dc ? 2 : (pt ? 3 : 0));
+		programArgs = new String[]{"-explain", "recompile_runtime", "-stats", "-args", fromFile ? DATASET_DIR + DATASET : input("data"),
 			output("result"), output("dict"), String.valueOf(recode), String.valueOf(dup), String.valueOf(fromFile),
 			String.valueOf(mode)};
 		if(error)
 			runTest(true, EXCEPTION_EXPECTED, DMLRuntimeException.class, -1);
 		else {
 			runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
-			FrameBlock dict_frame = readDMLFrameFromHDFS( "dict", Types.FileFormat.CSV);
-			int cols = recode? dict_frame.getNumRows() + 1 : dict_frame.getNumRows();
-			if(dup)
-				cols *= 2;
-			if(mode == 0){
-				HashMap res_actual = readDMLMatrixFromOutputDir("result");
-				double[][] result = TestUtils.convertHashMapToDoubleArray(res_actual, Math.min(sentenceColumn.length, 100),
-					cols);
-				checkResults(sentenceColumn, result, recodeColumn, dict_frame, dup ? 2 : 1);
+			if(TEST_NAME2.equals(testname)) {
+				double errorValue = readDMLScalarFromOutputDir("result").values()
+					.stream().findFirst().orElse(1000.0);
+				System.out.println(errorValue);
+				assert errorValue <= 10;
+			}
+			else {
+				FrameBlock dict_frame = readDMLFrameFromHDFS("dict", Types.FileFormat.CSV);
+				int cols = recode ? dict_frame.getNumRows() + 1 : dict_frame.getNumRows();
+				if(dup)
+					cols *= 2;
+				if(mode == 0) {
+					HashMap<MatrixValue.CellIndex, Double> res_actual = readDMLMatrixFromOutputDir("result");
+					double[][] result = TestUtils.convertHashMapToDoubleArray(res_actual, Math.min(sentenceColumn.length, 100),
+						cols);
+					checkResults(sentenceColumn, result, recodeColumn, dict_frame, dup ? 2 : 1);
+				}
 			}
-		}
@@ -211,6 +295,9 @@ public static void checkResults(String[] sentences, double[][] result,
 	{
 		HashMap[] indices = new HashMap[duplicates];
 		HashMap[] rcdMaps = new HashMap[duplicates];
+		String errors = "";
+		int num_errors = 0;
+		int max_errors = 100;
 		int frameCol = 0;
 		// even when the set of tokens is the same for duplicates, the order in which the token dicts are merged
 		// is not always the same for all columns in multithreaded mode
@@ -219,8 +306,9 @@ public static void checkResults(String[] sentences, double[][] result,
 			rcdMaps[i] = new HashMap<>();
 			for (int j = 0; j < dict.getNumRows(); j++) {
 				String[] tuple = dict.getString(j, frameCol).split("\u00b7");
-				indices[i].put(tuple[0], Integer.parseInt(tuple[1]));
+				indices[i].put(tuple[0], Integer.parseInt(tuple[1]) - 1);
 			}
+			System.out.println("Bow dict size: " + indices[i].size());
 			frameCol++;
 			if(recodeColumn != null){
 				for (int j = 0; j < dict.getNumRows(); j++) {
@@ -232,6 +320,7 @@ public static void checkResults(String[] sentences, double[][] result,
 				}
 				frameCol++;
 			}
+			System.out.println("Rec dict size: " + rcdMaps[i].size());
 		}
 
 		// only check the first 100 rows
@@ -266,19 +355,49 @@ public static void checkResults(String[] sentences, double[][] result,
 			for(Map.Entry<String, Integer> entry : count.entrySet()) {
 				String word = entry.getKey();
 				int count_expected = entry.getValue();
-				int index = indices[j].get(word);
-				assert result[row][index + offset] == count_expected;
+				Integer index = indices[j].get(word);
+				if(index == null) {
+					throw new AssertionError("row [" + row + "]: not found word: " + word);
+				}
+				if(result[row][index + offset] != count_expected) {
+					String error_message = "bow result[" + row + "," + (index + offset) + "]=" +
+						result[row][index + offset] + " does not match the expected: " + count_expected;
+					if(num_errors < max_errors)
+						errors += error_message + '\n';
+					else
+						throw new AssertionError(errors + error_message);
+					num_errors++;
+				}
+			}
+			for(int zeroIndex : zeroIndices) {
+				if(result[row][offset + zeroIndex] != 0) {
+					String error_message = "bow result[" + row + "," + (offset + zeroIndex) + "]=" +
+						result[row][offset + zeroIndex] + " does not match the expected: 0";
+					if(num_errors < max_errors)
+						errors += error_message + '\n';
+					else
+						throw new AssertionError(errors + error_message);
+					num_errors++;
+				}
 			}
-			for(int zeroIndex : zeroIndices)
-				assert result[row][offset + zeroIndex] == 0;
 			offset += indices[j].size();
 
 			// compare results: recode
 			if(recodeColumn != null){
-				assert result[row][offset] == rcdMaps[j].get(recodeColumn[row]);
+				if(result[row][offset] != rcdMaps[j].get(recodeColumn[row])) {
+					String error_message = "recode result[" + row + "," + offset + "]=" +
+						result[row][offset] + " does not match the expected: " + rcdMaps[j].get(recodeColumn[row]);
+					if(num_errors < max_errors)
+						errors += error_message + '\n';
+					else
+						throw new AssertionError(errors + error_message);
+					num_errors++;
+				}
 				offset++;
 			}
 		}
+		if (num_errors > 0)
+			throw new AssertionError(errors);
 	}
 
 	public static void writeStringsToCsvFile(String[] sentences, String[] recodeTokens, String fileName, boolean duplicate) throws IOException {
diff --git a/src/test/scripts/functions/transform/TransformFrameEncodeApplyBagOfWords.dml b/src/test/scripts/functions/transform/TransformFrameEncodeApplyBagOfWords.dml
new file mode 100644
index 00000000000..b2cd0ccda2c
--- /dev/null
+++ b/src/test/scripts/functions/transform/TransformFrameEncodeApplyBagOfWords.dml
@@ -0,0 +1,84 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read the token sequence (1K) w/ 100 distinct tokens
+Data = read($1, data_type="frame", format="csv");
+
+if(!as.boolean($4) & as.boolean($6)){
+  Data = Data[,1]
+}
+if(as.boolean($5) & as.boolean($6)){
+  Data = cbind(Data,Data)
+}
+while(FALSE){}
+
+if (as.boolean($4)) {
+  if (as.boolean($5)) {
+    jspec = "{ids: true, bag_of_words: [1,3], recode : [2,4]}";
+  } else {
+    jspec = "{ids: true, bag_of_words: [1], recode : [2]}";
+  }
+} else {
+  if (as.boolean($5)) {
+    jspec = "{ids: true, bag_of_words: [1,2]}";
+  } else {
+    jspec = "{ids: true, bag_of_words: [1]}";
+  }
+}
+if(as.integer($7) == 1){
+  jspec = "{ids: true, bag_of_words: [1], recode : [1]}";
+}
+if(as.integer($7) == 2){
+  jspec = "{ids: true, bag_of_words: [1], dummycode : [2]}";
+}
+if(as.integer($7) == 3){
+  ones = as.frame(matrix(1, nrow(Data), 1))
+  Data = cbind(Data, ones)
+  jspec = "{ids: true, bag_of_words: [1]}";
+}
+
+[Data_enc, Meta] = transformencode(target=Data, spec=jspec);
+while(FALSE){}
+
+i = 0
+total = 0
+j = 0
+# set to 20 for benchmarking
+while(i < 30){
+  t0 = time()
+  Data_enc2 = transformapply(target=Data, spec=jspec, meta=Meta)
+  if(i > 10){
+    total = total + time() - t0
+    j = j + 1
+  }
+  i = i + 1
+}
+print(total/1000000000 / j)
+
+i = 0
+
+Error = sign(Data_enc2 - Data_enc)
+Error_agg = sum(Error * Error)
+#print(sum(sign(Data_enc2)))
+#print(sum(sign(Data_enc)))
+#print(Error_agg)
+write(Error_agg, $2, format="text");
diff --git a/src/test/scripts/functions/transform/TransformFrameEncodeBagOfWords.dml b/src/test/scripts/functions/transform/TransformFrameEncodeBagOfWords.dml
index 49231c97c0b..2a69f314a43 100644
--- a/src/test/scripts/functions/transform/TransformFrameEncodeBagOfWords.dml
+++ b/src/test/scripts/functions/transform/TransformFrameEncodeBagOfWords.dml
@@ -21,15 +21,15 @@
 
 # Read the token sequence (1K) w/ 100 distinct tokens
 Data = read($1, data_type="frame", format="csv");
+#print(toString(Data))
 
 if(!as.boolean($4) & as.boolean($6)){
   Data = Data[,1]
 }
+while(FALSE){}
 if(as.boolean($5) & as.boolean($6)){
   Data = cbind(Data,Data)
 }
-while(FALSE){}
-
 if (as.boolean($4)) {
   if (as.boolean($5)) {
     jspec = "{ids: true, bag_of_words: [1,3], recode : [2,4]}";
@@ -54,7 +54,7 @@ if(as.integer($7) == 3){
   Data = cbind(Data, ones)
   jspec = "{ids: true, bag_of_words: [1]}";
 }
-
+while(FALSE){}
 i = 0
 total = 0
 j = 0