[SYSTEMDS-3782] Bag-of-words encoder for SP
e-strauss committed Nov 22, 2024
1 parent 4e00aa1 commit 48cb3c5
Showing 12 changed files with 800 additions and 161 deletions.
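For orientation before the diff: a bag-of-words encoder turns a string column into one output column per distinct token, where each cell holds how often that token occurs in the row. The sketch below is illustrative standalone Java, not SystemDS code; it also shows the 1-based token codes this commit adopts (code 0 is then free to signal an unseen token):

	import java.util.*;

	public class BowSketch {
		public static void main(String[] args) {
			String[] column = {"the cat sat", "the the dog", null};
			// build phase: assign 1-based codes to distinct tokens
			Map<String, Long> dict = new LinkedHashMap<>();
			for(String row : column)
				if(row != null)
					for(String tok : row.toLowerCase().split("\\s+"))
						if(!tok.isEmpty())
							dict.putIfAbsent(tok, (long) dict.size() + 1);
			// apply phase: one output column per token, cells hold counts
			int[][] out = new int[column.length][dict.size()];
			for(int r = 0; r < column.length; r++)
				if(column[r] != null)
					for(String tok : column[r].toLowerCase().split("\\s+"))
						if(dict.containsKey(tok))
							out[r][(int) (dict.get(tok) - 1)]++;
			System.out.println(dict);                     // {the=1, cat=2, sat=3, dog=4}
			System.out.println(Arrays.deepToString(out)); // [[1, 1, 1, 0], [2, 0, 0, 1], [0, 0, 0, 0]]
		}
	}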
File 1 of 12:
@@ -62,6 +62,7 @@
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.transform.TfUtils;
import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderBagOfWords;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderBin;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode;
@@ -263,6 +264,7 @@ public Iterator<Tuple2<Integer, Object>> call(Iterator<Tuple2<Long, FrameBlock>>
// encoder-specific outputs
List<ColumnEncoderRecode> raEncoders = _encoder.getColumnEncoders(ColumnEncoderRecode.class);
List<ColumnEncoderBin> baEncoders = _encoder.getColumnEncoders(ColumnEncoderBin.class);
List<ColumnEncoderBagOfWords> bowEncoders = _encoder.getColumnEncoders(ColumnEncoderBagOfWords.class);
ArrayList<Tuple2<Integer, Object>> ret = new ArrayList<>();

// output recode maps as columnID - token pairs
@@ -273,8 +275,14 @@ public Iterator<Tuple2<Integer, Object>> call(Iterator<Tuple2<Long, FrameBlock>>
for(Entry<Integer, HashSet<Object>> e1 : tmp.entrySet())
for(Object token : e1.getValue())
ret.add(new Tuple2<>(e1.getKey(), token));
- if(!raEncoders.isEmpty())
- raEncoders.forEach(columnEncoderRecode -> columnEncoderRecode.getCPRecodeMapsPartial().clear());
+ raEncoders.forEach(columnEncoderRecode -> columnEncoderRecode.getCPRecodeMapsPartial().clear());
}

if(!bowEncoders.isEmpty()){
for (ColumnEncoderBagOfWords bowEnc : bowEncoders)
for (Object token : bowEnc.getPartialTokenDictionary())
ret.add(new Tuple2<>(bowEnc.getColID(), token));
bowEncoders.forEach(enc -> enc.getPartialTokenDictionary().clear());
}

// output binning column min/max as columnID - min/max pairs
@@ -321,7 +329,8 @@ public Iterator<String> call(Tuple2<Integer, Iterable<Object>> arg0) throws Exce
StringBuilder sb = new StringBuilder();

// handle recode maps
- if(_encoder.containsEncoderForID(colID, ColumnEncoderRecode.class)) {
+ if(_encoder.containsEncoderForID(colID, ColumnEncoderRecode.class) ||
+ _encoder.containsEncoderForID(colID, ColumnEncoderBagOfWords.class)) {
while(iter.hasNext()) {
String token = TfUtils.sanitizeSpaces(iter.next().toString());
sb.append(rowID).append(' ').append(scolID).append(' ');
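The build path in this file reuses the recode machinery: each partition emits (columnID, token) pairs, Spark groups them by column, and consecutive codes are assigned when the grouped tokens are written to the metadata frame. A rough single-JVM sketch of that aggregation (illustrative names, not the actual Spark job):

	import java.util.*;

	public class GroupTokensSketch {
		public static void main(String[] args) {
			// (columnID, token) pairs as emitted by the partial build of each partition
			List<Map.Entry<Integer, String>> pairs = List.of(
				Map.entry(1, "cat"), Map.entry(1, "dog"), Map.entry(1, "cat"), Map.entry(2, "x"));
			// group by column and assign consecutive 1-based codes per column
			Map<Integer, Map<String, Long>> dicts = new HashMap<>();
			for(Map.Entry<Integer, String> p : pairs) {
				Map<String, Long> dict = dicts.computeIfAbsent(p.getKey(), k -> new LinkedHashMap<>());
				dict.putIfAbsent(p.getValue(), (long) dict.size() + 1);
			}
			System.out.println(dicts); // e.g. {1={cat=1, dog=2}, 2={x=1}}
		}
	}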
File 2 of 12:
@@ -88,6 +88,7 @@
import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
import org.apache.sysds.runtime.transform.decode.Decoder;
import org.apache.sysds.runtime.transform.decode.DecoderFactory;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderBagOfWords;
import org.apache.sysds.runtime.transform.encode.EncoderFactory;
import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
@@ -1056,6 +1057,13 @@ public Tuple2<Long, FrameBlock> call(Tuple2<Long, FrameBlock> in) throws Excepti

// execute block transform apply
MultiColumnEncoder encoder = _bencoder.getValue();
// we need to create a copy of the encoder, since the bag-of-words encoder stores frameblock-specific
// state which would be overwritten when multiple blocks are located on an executor;
// to avoid this, we create a shallow copy of the MultiColumnEncoder in which only the bag-of-words
// encoder objects are instantiated anew (with fresh frameblock-specific fields) while the other
// fields (like meta) are shallow-copied; all other encoders are reused, not newly instantiated
if(!encoder.getColumnEncoders(ColumnEncoderBagOfWords.class).isEmpty())
encoder = new MultiColumnEncoder(encoder); // create copy
MatrixBlock tmp = encoder.apply(blk);
// remap keys
if(_omap != null) {
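The copy above is needed because ColumnEncoderBagOfWords keeps per-frameblock state (such as nnzPerRow) that recode or bin encoders do not have, so two blocks applied on the same executor would otherwise overwrite each other's state. A minimal standalone sketch of the pattern, with hypothetical names:

	import java.util.HashMap;
	import java.util.Map;

	public class CopyPerBlock {
		// hypothetical stand-in for the broadcast encoder
		static class Encoder {
			final Map<String, Long> dict; // token dictionary: safe to share across blocks
			int[] nnzPerRow;              // per-frameblock state: must not be shared
			Encoder(Map<String, Long> dict) { this.dict = dict; }
			Encoder copyForBlock() { return new Encoder(dict); } // shallow copy, fresh row state
		}
		public static void main(String[] args) {
			Encoder shared = new Encoder(new HashMap<>(Map.of("cat", 1L)));
			Encoder blockA = shared.copyForBlock(); // each block on the executor gets its own copy
			Encoder blockB = shared.copyForBlock();
			blockA.nnzPerRow = new int[] {1, 0};    // writing A's state...
			blockB.nnzPerRow = new int[] {2};       // ...cannot clobber B's
			System.out.println(blockA.dict == blockB.dict); // true: dictionary is shared
		}
	}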
File 3 of 12:
@@ -450,7 +450,7 @@ protected void setApplyRowBlocksPerColumn(int nPart) {
}

public enum EncoderType {
- Recode, FeatureHash, PassThrough, Bin, Dummycode, Omit, MVImpute, Composite, WordEmbedding,
+ Recode, FeatureHash, PassThrough, Bin, Dummycode, Omit, MVImpute, Composite, WordEmbedding, BagOfWords
}

/*
File 4 of 12:
@@ -29,6 +29,9 @@
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.utils.stats.TransformStatistics;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
@@ -46,7 +49,8 @@ public class ColumnEncoderBagOfWords extends ColumnEncoder {

public static int NUM_SAMPLES_MAP_ESTIMATION = 16000;
protected int[] nnzPerRow;
- private Map<String, Integer> tokenDictionary;
+ private Map<Object, Long> tokenDictionary; // switched from int to long to reuse code from RecodeEncoder
+ private HashSet<Object> tokenDictionaryPart = null;
protected String seperatorRegex = "\\s+"; // whitespace
protected boolean caseSensitive = false;
protected long nnz = 0;
@@ -62,11 +66,43 @@ public ColumnEncoderBagOfWords() {
super(-1);
}

public ColumnEncoderBagOfWords(ColumnEncoderBagOfWords enc) {
super(enc._colID);
this.nnzPerRow = enc.nnzPerRow != null ? enc.nnzPerRow.clone() : null;
this.tokenDictionary = enc.tokenDictionary;
this.seperatorRegex = enc.seperatorRegex;
this.caseSensitive = enc.caseSensitive;
}

public void setTokenDictionary(HashMap<Object, Long> dict){
this.tokenDictionary = dict;
}

public Map<Object, Long> getTokenDictionary() {
return tokenDictionary;
}

protected void initNnzPartials(int rows, int numBlocks){
this.nnzPerRow = new int[rows];
this.nnzPartials = new long[numBlocks];
}

public double computeNnzEstimate(CacheBlock<?> in, int[] sampleIndices) {
// estimates the nnz per row for this encoder
final int max_index = Math.min(ColumnEncoderBagOfWords.NUM_SAMPLES_MAP_ESTIMATION, sampleIndices.length);
int nnz = 0;
for (int i = 0; i < max_index; i++) {
int sind = sampleIndices[i];
String current = in.getString(sind, this._colID - 1);
if(current != null)
for(String token : tokenize(current, caseSensitive, seperatorRegex))
if(!token.isEmpty() && tokenDictionary.containsKey(token)){
nnz++;
}
}
return (double) nnz / max_index;
}
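In other words, computeNnzEstimate samples up to NUM_SAMPLES_MAP_ESTIMATION rows and returns the average number of in-dictionary token occurrences per sampled row, which a caller can scale by the total row count to pre-size the sparse output. A worked example with made-up numbers:

	public class NnzEstimate {
		public static void main(String[] args) {
			// illustrative numbers only, not measured values
			int sampledRows = 16000;        // NUM_SAMPLES_MAP_ESTIMATION
			long dictHitsInSample = 48000;  // in-dictionary token occurrences counted over the sample
			double nnzPerRowEst = (double) dictHitsInSample / sampledRows;
			long totalRows = 1_000_000;
			long estTotalNnz = (long) Math.ceil(nnzPerRowEst * totalRows);
			System.out.println(nnzPerRowEst + " nnz/row -> ~" + estTotalNnz + " total nnz"); // 3.0 -> ~3000000
		}
	}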

public void computeMapSizeEstimate(CacheBlock<?> in, int[] sampleIndices) {
// Find the frequencies of distinct values in the sample after tokenization
HashMap<String, Integer> distinctFreq = new HashMap<>();
@@ -118,6 +154,18 @@ public void computeMapSizeEstimate(CacheBlock<?> in, int[] sampleIndices) {
_estMetaSize = _estNumDistincts * _avgEntrySize;
}

public void computeNnzPerRow(CacheBlock<?> in, int start, int end){
for (int i = start; i < end; i++) {
String current = in.getString(i, this._colID - 1);
HashSet<String> distinctTokens = new HashSet<>();
if(current != null)
for(String token : tokenize(current, caseSensitive, seperatorRegex))
if(!token.isEmpty() && tokenDictionary.containsKey(token))
distinctTokens.add(token);
this.nnzPerRow[i] = distinctTokens.size();
}
}
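Note the asymmetry between the two passes: computeNnzPerRow counts distinct in-dictionary tokens (how many cells a row occupies), while apply later writes occurrence counts as the cell values. For the row "a b a c" with dictionary {a, b}, the row has 2 nonzeros holding a=2 and b=1; a tiny standalone check under those assumptions:

	import java.util.*;

	public class NnzVsCounts {
		public static void main(String[] args) {
			Set<String> dict = Set.of("a", "b");
			String row = "a b a c";
			Set<String> distinct = new HashSet<>();
			Map<String, Integer> counts = new HashMap<>();
			for(String tok : row.split("\\s+"))
				if(dict.contains(tok)) {
					distinct.add(tok);
					counts.merge(tok, 1, Integer::sum);
				}
			System.out.println(distinct.size()); // 2 -> nnzPerRow entry
			System.out.println(counts);          // {a=2, b=1} -> cell values (order may vary)
		}
	}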

public static String[] tokenize(String current, boolean caseSensitive, String seperatorRegex) {
// string builder is faster than regex
StringBuilder finalString = new StringBuilder();
@@ -150,8 +198,6 @@ protected TransformType getTransformType() {
return TransformType.BAG_OF_WORDS;
}



public Callable<Object> getBuildTask(CacheBlock<?> in) {
return new ColumnBagOfWordsBuildTask(this, in);
}
@@ -160,7 +206,7 @@ public Callable<Object> getBuildTask(CacheBlock<?> in) {
public void build(CacheBlock<?> in) {
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
tokenDictionary = new HashMap<>(_estNumDistincts);
- int i = 0;
+ int i = 1;
this.nnz = 0;
nnzPerRow = new int[in.getNumRows()];
HashSet<String> tokenSetPerRow;
@@ -173,7 +219,7 @@
if(!token.isEmpty()){
tokenSetPerRow.add(token);
if(!this.tokenDictionary.containsKey(token))
- this.tokenDictionary.put(token, i++);
+ this.tokenDictionary.put(token, (long) i++);
}
this.nnzPerRow[r] = tokenSetPerRow.size();
this.nnz += tokenSetPerRow.size();
@@ -205,6 +251,33 @@ static class Pair {
}
}

@Override
public void prepareBuildPartial() {
// ensure the partial token dictionary is allocated
if(tokenDictionaryPart == null)
tokenDictionaryPart = new HashSet<>();
}


public HashSet<Object> getPartialTokenDictionary(){
return this.tokenDictionaryPart;
}

@Override
public void buildPartial(FrameBlock in) {
if(!isApplicable())
return;
for (int r = 0; r < in.getNumRows(); r++) {
String current = in.getString(r, this._colID - 1);
if(current != null)
for(String token : tokenize(current, caseSensitive, seperatorRegex)){
if(!token.isEmpty()){
tokenDictionaryPart.add(token);
}
}
}
}
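buildPartial only collects the distinct token set of a frame block; code assignment happens later when the partial sets are merged (see BowMergePartialBuildTask further down, which appends unseen tokens with code tokenDictionary.size() + 1, keeping codes 1-based). A condensed standalone sketch of that two-step flow:

	import java.util.*;

	public class PartialBuildSketch {
		public static void main(String[] args) {
			// step 1: each block contributes an unordered distinct token set
			List<Set<String>> partials = List.of(Set.of("cat", "dog"), Set.of("dog", "fish"));
			// step 2: merge into one 1-based dictionary
			Map<String, Long> dict = new LinkedHashMap<>();
			for(Set<String> part : partials)
				for(String tok : part)
					dict.putIfAbsent(tok, (long) dict.size() + 1);
			System.out.println(dict); // e.g. {cat=1, dog=2, fish=3} (order depends on set iteration)
		}
	}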

protected void applySparse(CacheBlock<?> in, MatrixBlock out, int outputCol, int rowStart, int blk) {
boolean mcsr = MatrixBlock.DEFAULT_SPARSEBLOCK == SparseBlock.Type.MCSR;
mcsr = false; // force CSR for transformencode FIXME
@@ -214,20 +287,21 @@ protected void applySparse(CacheBlock<?> in, MatrixBlock out, int outputCol, int
throw new NotImplementedException();
}
else { // csr
- HashMap<String, Integer> counter = countTokenAppearances(in, r, _colID-1, caseSensitive, seperatorRegex);
+ HashMap<String, Integer> counter = countTokenAppearances(in, r);
if(counter.isEmpty())
sparseRowsWZeros.add(r);
else {
SparseBlockCSR csrblock = (SparseBlockCSR) out.getSparseBlock();
int[] rptr = csrblock.rowPointers();
// assert that nnz from build is equal to nnz from apply
assert counter.size() == nnzPerRow[r];
- Pair[] columnValuePairs = new Pair[counter.size()];
+ Pair[] columnValuePairs = new Pair[nnzPerRow[r]];
int i = 0;
for (Map.Entry<String, Integer> entry : counter.entrySet()) {
String token = entry.getKey();
- columnValuePairs[i] = new Pair(outputCol + tokenDictionary.get(token), entry.getValue());
- i++;
+ columnValuePairs[i] = new Pair((int) (outputCol + tokenDictionary.getOrDefault(token, 0L) - 1), entry.getValue());
+ // if the token is not in the dictionary, columnValuePairs[i] is overwritten in the next iteration
+ i += tokenDictionary.containsKey(token) ? 1 : 0;
}
// insertion sort performs better on small arrays
if(columnValuePairs.length >= 128)
@@ -264,20 +338,20 @@ private static void insertionSort(Pair [] arr) {
@Override
protected void applyDense(CacheBlock<?> in, MatrixBlock out, int outputCol, int rowStart, int blk){
for (int r = rowStart; r < Math.max(in.getNumRows(), rowStart + blk); r++) {
- HashMap<String, Integer> counter = countTokenAppearances(in, r, _colID-1, caseSensitive, seperatorRegex);
+ HashMap<String, Integer> counter = countTokenAppearances(in, r);
for (Map.Entry<String, Integer> entry : counter.entrySet())
- out.set(r, outputCol + tokenDictionary.get(entry.getKey()), entry.getValue());
+ out.set(r, (int) (outputCol + tokenDictionary.get(entry.getKey()) - 1), entry.getValue());
}
}

- private static HashMap<String, Integer> countTokenAppearances(
- CacheBlock<?> in, int r, int c, boolean caseSensitive, String separator)
+ private HashMap<String, Integer> countTokenAppearances(
+ CacheBlock<?> in, int r)
{
- String current = in.getString(r, c);
+ String current = in.getString(r, _colID - 1);
HashMap<String, Integer> counter = new HashMap<>();
if(current != null)
- for (String token : tokenize(current, caseSensitive, separator))
- if (!token.isEmpty())
+ for (String token : tokenize(current, caseSensitive, seperatorRegex))
+ if (!token.isEmpty() && tokenDictionary.containsKey(token))
counter.put(token, counter.getOrDefault(token, 0) + 1);
return counter;
}
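applySparse sorts the (column, value) pairs of a row before inserting them into the CSR block, preferring insertion sort below the 128-element cutoff above since it beats the general sort on short arrays. A generic standalone version of such a pair sort, simplified from the insertionSort helper here:

	import java.util.Arrays;

	public class PairSort {
		// simplified insertion sort over (column, value) pairs, ascending by column
		static void insertionSort(int[][] pairs) {
			for(int i = 1; i < pairs.length; i++) {
				int[] key = pairs[i];
				int j = i - 1;
				while(j >= 0 && pairs[j][0] > key[0]) {
					pairs[j + 1] = pairs[j];
					j--;
				}
				pairs[j + 1] = key;
			}
		}
		public static void main(String[] args) {
			int[][] pairs = {{5, 2}, {1, 1}, {3, 4}};
			insertionSort(pairs);
			System.out.println(Arrays.deepToString(pairs)); // [[1, 1], [3, 4], [5, 2]]
		}
	}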
@@ -291,15 +365,41 @@ public void allocateMetaData(FrameBlock meta) {
public FrameBlock getMetaData(FrameBlock out) {
int rowID = 0;
StringBuilder sb = new StringBuilder();
- for(Map.Entry<String, Integer> e : this.tokenDictionary.entrySet()) {
- out.set(rowID++, _colID - 1, constructRecodeMapEntry(e.getKey(), Long.valueOf(e.getValue()), sb));
+ for(Map.Entry<Object, Long> e : this.tokenDictionary.entrySet()) {
+ out.set(rowID++, _colID - 1, constructRecodeMapEntry(e.getKey(), e.getValue(), sb));
}
return out;
}

@Override
public void initMetaData(FrameBlock meta) {
- throw new NotImplementedException();
+ if(meta != null && meta.getNumRows() > 0) {
+ this.tokenDictionary = meta.getRecodeMap(_colID - 1);
+ }
}

@Override
public void writeExternal(ObjectOutput out) throws IOException {
super.writeExternal(out);

out.writeInt(tokenDictionary == null ? 0 : tokenDictionary.size());
if(tokenDictionary != null)
for(Map.Entry<Object, Long> e : tokenDictionary.entrySet()) {
out.writeUTF((String) e.getKey());
out.writeLong(e.getValue());
}
}

@Override
public void readExternal(ObjectInput in) throws IOException {
super.readExternal(in);
int size = in.readInt();
tokenDictionary = new HashMap<>(size * 4 / 3);
for(int j = 0; j < size; j++) {
String key = in.readUTF();
Long value = in.readLong();
tokenDictionary.put(key, value);
}
}
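The writeExternal/readExternal pair above ships the token dictionary with the serialized encoder as a size followed by (UTF string, long) entries. A standalone round trip of that wire format (plain Java streams, illustrative only):

	import java.io.*;
	import java.util.*;

	public class DictWireFormat {
		public static void main(String[] args) throws IOException {
			Map<String, Long> dict = new LinkedHashMap<>(Map.of("cat", 1L, "dog", 2L));
			ByteArrayOutputStream bos = new ByteArrayOutputStream();
			try(ObjectOutputStream out = new ObjectOutputStream(bos)) {
				out.writeInt(dict.size());
				for(Map.Entry<String, Long> e : dict.entrySet()) {
					out.writeUTF(e.getKey());
					out.writeLong(e.getValue());
				}
			}
			Map<String, Long> back = new HashMap<>();
			try(ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray()))) {
				int size = in.readInt();
				for(int i = 0; i < size; i++)
					back.put(in.readUTF(), in.readLong());
			}
			System.out.println(back); // {cat=1, dog=2}
		}
	}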

private static class BowPartialBuildTask implements Callable<Object> {
@@ -378,11 +478,11 @@ private BowMergePartialBuildTask(ColumnEncoderBagOfWords encoderRecode, HashMap<
@Override
public Object call() {
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
- Map<String, Integer> tokenDictionary = _encoder.tokenDictionary;
+ Map<Object, Long> tokenDictionary = _encoder.tokenDictionary;
for(Object tokenSet : _partialMaps.values()){
( (HashSet<?>) tokenSet).forEach(token -> {
if(!tokenDictionary.containsKey(token))
- tokenDictionary.put((String) token, tokenDictionary.size());
+ tokenDictionary.put(token, (long) tokenDictionary.size() + 1);
});
}
for (long nnzPartial : _encoder.nnzPartials)
File 5 of 12:
@@ -62,9 +62,9 @@ public ColumnEncoderComposite() {

public ColumnEncoderComposite(List<ColumnEncoder> columnEncoders, FrameBlock meta) {
super(-1);
- if(!(columnEncoders.size() > 0 &&
+ if(!(!columnEncoders.isEmpty() &&
columnEncoders.stream().allMatch((encoder -> encoder._colID == columnEncoders.get(0)._colID))))
- throw new DMLRuntimeException("Tried to create Composite Encoder with no encoders or mismatching columIDs");
+ throw new DMLRuntimeException("Tried to create Composite Encoder with no encoders or mismatching columnIDs");
_colID = columnEncoders.get(0)._colID;
_meta = meta;
_columnEncoders = columnEncoders;
@@ -73,6 +73,11 @@ public ColumnEncoderComposite(List<ColumnEncoder> columnEncoders, FrameBlock met
public ColumnEncoderComposite(List<ColumnEncoder> columnEncoders) {
this(columnEncoders, null);
}
public ColumnEncoderComposite(List<ColumnEncoder> columnEncoders, int colID) {
super(colID);
_columnEncoders = columnEncoders;
_meta = null;
}

public ColumnEncoderComposite(ColumnEncoder columnEncoder) {
super(columnEncoder._colID);
@@ -166,7 +171,8 @@ public List<DependencyTask<?>> getBuildTasks(CacheBlock<?> in) {
if(t == null)
continue;
// Linear execution between encoders so they can't be built in parallel
- if(tasks.size() != 0) {
+ if(!tasks.isEmpty()) {
// TODO: is this still needed? Currently there is no CompositeEncoder with two encoders that have a build phase.
// avoid unnecessary map initialization
depMap = (depMap == null) ? new HashMap<>() : depMap;
// This workaround is needed since sublist is only valid for effective final lists,
@@ -207,6 +213,8 @@ public void buildPartial(FrameBlock in) {
public MatrixBlock apply(CacheBlock<?> in, MatrixBlock out, int outputCol, int rowStart, int blk) {
try {
for(int i = 0; i < _columnEncoders.size(); i++) {
// set sparseRowPointerOffset in the encoder
_columnEncoders.get(i).sparseRowPointerOffset = this.sparseRowPointerOffset;
if(i == 0) {
// 1. encoder writes data into MatrixBlock Column all others use this column for further encoding
_columnEncoders.get(i).apply(in, out, outputCol, rowStart, blk);
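The sparseRowPointerOffset hand-off above exists because a bag-of-words column emits a variable number of nonzeros per row, so encoders writing later columns must shift their CSR row pointers by the surplus. A simplified standalone illustration of that bookkeeping (hypothetical numbers; the real code operates on SparseBlockCSR):

	public class CsrOffsetSketch {
		public static void main(String[] args) {
			// hypothetical: nonzeros the bow column produces per row (a plain encoder produces exactly 1)
			int[] nnzPerRow = {2, 0, 3};
			int[] shift = new int[nnzPerRow.length + 1];
			for(int r = 0; r < nnzPerRow.length; r++)
				// the surplus over the usual single entry shifts every later column's entries below row r
				shift[r + 1] = shift[r] + (nnzPerRow[r] - 1);
			System.out.println(java.util.Arrays.toString(shift)); // [0, 1, 0, 2]
		}
	}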
File 6 of 12:
@@ -265,6 +265,8 @@ else if(columnEncoder instanceof ColumnEncoderRecode)
return EncoderType.Recode.ordinal();
else if(columnEncoder instanceof ColumnEncoderWordEmbedding)
return EncoderType.WordEmbedding.ordinal();
else if(columnEncoder instanceof ColumnEncoderBagOfWords)
return EncoderType.BagOfWords.ordinal();
throw new DMLRuntimeException("Unsupported encoder type: " + columnEncoder.getClass().getCanonicalName());
}

@@ -283,6 +285,8 @@ public static ColumnEncoder createInstance(int type) {
return new ColumnEncoderRecode();
case WordEmbedding:
return new ColumnEncoderWordEmbedding();
case BagOfWords:
return new ColumnEncoderBagOfWords();
default:
throw new DMLRuntimeException("Unsupported encoder type: " + etype);
}
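The factory above persists an encoder's type as its enum ordinal and reconstructs it via createInstance, which is why BagOfWords is appended at the end of EncoderType: existing ordinals stay stable. A small round-trip example over a copy of the enum:

	enum EncoderType { Recode, FeatureHash, PassThrough, Bin, Dummycode, Omit, MVImpute, Composite, WordEmbedding, BagOfWords }

	public class OrdinalRoundTrip {
		public static void main(String[] args) {
			int wire = EncoderType.BagOfWords.ordinal();  // 9, stable because it was appended
			EncoderType back = EncoderType.values()[wire];
			System.out.println(wire + " -> " + back);     // 9 -> BagOfWords
		}
	}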
(6 more changed files not shown)
