Skip to content

Commit

Permalink
Implement the user suggestion POC
Browse files Browse the repository at this point in the history
  • Loading branch information
maurever committed Nov 1, 2023
1 parent 54149b7 commit 71115a0
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 46 deletions.
53 changes: 36 additions & 17 deletions h2o-core/src/main/java/water/rapids/ast/prims/search/AstMatch.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@
public class AstMatch extends AstPrimitive {
@Override
public String[] args() {
return new String[]{"ary", "table", "nomatch", "incomparables"};
return new String[]{"ary", "table", "nomatch", "incomparables", "indexes"};
}

@Override
public int nargs() {
return 1 + 4;
} // (match fr table nomatch incomps)
return 1 + 5;
} // (match fr table nomatch incomps indexes)

@Override
public String str() {
Expand All @@ -45,18 +45,22 @@ public ValFrame apply(Env env, Env.StackHelp stk, AstRoot asts[]) {

final MRTask<?> matchTask;
double noMatch = asts[3].exec(env).getNum();
boolean indexes = asts[5].exec(env).getBool();


if (asts[2] instanceof AstNumList) {
matchTask = new NumMatchTask(((AstNumList) asts[2]).sort().expand(), noMatch);
matchTask = new NumMatchTask(((AstNumList) asts[2]).sort().expand(), noMatch, indexes);
} else if (asts[2] instanceof AstNum) {
matchTask = new NumMatchTask(new double[]{asts[2].exec(env).getNum()}, noMatch);
matchTask = new NumMatchTask(new double[]{asts[2].exec(env).getNum()}, noMatch, indexes);
} else if (asts[2] instanceof AstStrList) {
String[] values = ((AstStrList) asts[2])._strs;
Arrays.sort(values);
matchTask = fr.anyVec().isString() ? new StrMatchTask(values, noMatch) : new CatMatchTask(values, noMatch);
matchTask = fr.anyVec().isString() ? new StrMatchTask(values, noMatch, indexes) :
new CatMatchTask(values, noMatch, indexes);
} else if (asts[2] instanceof AstStr) {
String[] values = new String[]{asts[2].exec(env).getStr()};
matchTask = fr.anyVec().isString() ? new StrMatchTask(values, noMatch) : new CatMatchTask(values, noMatch);
matchTask = fr.anyVec().isString() ? new StrMatchTask(values, noMatch, indexes) :
new CatMatchTask(values, noMatch, indexes);
} else
throw new IllegalArgumentException("Expected numbers/strings. Got: " + asts[2].getClass());

Expand All @@ -67,34 +71,43 @@ public ValFrame apply(Env env, Env.StackHelp stk, AstRoot asts[]) {
private static class StrMatchTask extends MRTask<CatMatchTask> {
String[] _values;
double _noMatch;
StrMatchTask(String[] values, double noMatch) {
boolean _indexes;

StrMatchTask(String[] values, double noMatch, boolean indexes) {
_values = values;
_noMatch = noMatch;
_indexes = indexes;
}
@Override
public void map(Chunk c, NewChunk nc) {
BufferedString bs = new BufferedString();
int rows = c._len;
for (int r = 0; r < rows; r++) {
double x = c.isNA(r) ? _noMatch : in(_values, c.atStr(bs, r).toString(), _noMatch);
double x = c.isNA(r) ? _noMatch : in(_values, c.atStr(bs, r).toString(), _noMatch, _indexes);
nc.addNum(x);
}
}
}

private static class CatMatchTask extends MRTask<CatMatchTask> {
String[] _values;
int[] _firstMatchRow;
double _noMatch;
CatMatchTask(String[] values, double noMatch) {
boolean _indexes;

CatMatchTask(String[] values, double noMatch, boolean indexes) {
_values = values;
_noMatch = noMatch;
_indexes = indexes;
_firstMatchRow = new int[values.length];
}

@Override
public void map(Chunk c, NewChunk nc) {
String[] domain = c.vec().domain();
int rows = c._len;
for (int r = 0; r < rows; r++) {
double x = c.isNA(r) ? _noMatch : in(_values, domain[(int) c.at8(r)], _noMatch);
double x = c.isNA(r) ? _noMatch : in(_values, domain[(int) c.at8(r)], _noMatch, _indexes);
nc.addNum(x);
}
}
Expand All @@ -103,26 +116,32 @@ public void map(Chunk c, NewChunk nc) {
private static class NumMatchTask extends MRTask<CatMatchTask> {
double[] _values;
double _noMatch;
NumMatchTask(double[] values, double noMatch) {
boolean _indexes;

NumMatchTask(double[] values, double noMatch, boolean indexes) {
_values = values;
_noMatch = noMatch;
_indexes = indexes;
}

@Override
public void map(Chunk c, NewChunk nc) {
int rows = c._len;
for (int r = 0; r < rows; r++) {
double x = c.isNA(r) ? _noMatch : in(_values, c.atd(r), _noMatch);
double x = c.isNA(r) ? _noMatch : in(_values, c.atd(r), _noMatch, _indexes);
nc.addNum(x);
}
}
}

private static double in(String[] matches, String s, double nomatch) {
return Arrays.binarySearch(matches, s) >= 0 ? 1 : nomatch;
private static double in(String[] matches, String s, double nomatch, boolean indexes) {
int match = Arrays.binarySearch(matches, s);
return match >= 0 ? indexes ? match : 1 : nomatch;
}

private static double in(double[] matches, double d, double nomatch) {
return binarySearchDoublesUlp(matches, 0, matches.length, d) >= 0 ? 1 : nomatch;
private static double in(double[] matches, double d, double nomatch, boolean indexes) {
int match = binarySearchDoublesUlp(matches, 0, matches.length, d);
return match >= 0 ? indexes ? match : 1 : nomatch;
}

private static int binarySearchDoublesUlp(double[] a, int from, int to, double key) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ public class AstMatchTest extends TestUtil {
static public void setup() { stall_till_cloudsize(1); }

@Test
public void testMatchNumList() throws Exception {
public void testMatchNumList() {
final Frame data = makeTestFrame();
Frame output = null;
try {
String numList = idx(data.vec(2), "cB", "cC", "cD");
String rapids = "(tmp= tst (match (cols data [2]) [" + numList + "] -1 ignored))";
String numList = idx(data.vec(3), "cB", "cC", "cD");
String rapids = "(tmp= tst (match (cols data [3]) [" + numList + "] -1 ignored false))";
Val val = Rapids.exec(rapids);
output = val.getFrame();
assertVecEquals(data.vec(0), output.vec(0), 0.0);
Expand All @@ -35,11 +35,11 @@ public void testMatchNumList() throws Exception {
}

@Test
public void testMatchCatList() throws Exception {
public void testMatchCatList() {
final Frame data = makeTestFrame();
Frame output = null;
try {
String rapids = "(tmp= tst (match (cols data [2]) [\"cD\",\"cC\",\"cB\"] -1 ignored))";
String rapids = "(tmp= tst (match (cols data [3]) [\"cD\",\"cC\",\"cB\"] -1 ignored false))";
Val val = Rapids.exec(rapids);
output = val.getFrame();
assertVecEquals(data.vec(0), output.vec(0), 0.0);
Expand All @@ -52,11 +52,11 @@ public void testMatchCatList() throws Exception {
}

@Test
public void testMatchStrList() throws Exception {
public void testMatchStrList() {
final Frame data = makeTestFrame();
Frame output = null;
try {
String rapids = "(tmp= tst (match (cols data [1]) [\"sD\",\"sC\",\"sB\"] -1 ignored))";
String rapids = "(tmp= tst (match (cols data [2]) [\"sD\",\"sC\",\"sB\"] -1 ignored false))";
Val val = Rapids.exec(rapids);
output = val.getFrame();
assertVecEquals(data.vec(0), output.vec(0), 0.0);
Expand All @@ -68,26 +68,81 @@ public void testMatchStrList() throws Exception {
}
}

/**
 * Runs {@code match} with the trailing {@code indexes} argument set to {@code true} on the
 * categorical column (column 3), using a lookup list produced by {@code idx(...)} for the
 * labels "cB", "cC" and "cD", and {@code -1} as the no-match value.
 *
 * The result is compared against column 1 of the test frame, which {@code makeTestFrame()}
 * fills with 0 / 1 / 2 for B / C / D and -1 for every other letter.
 */
@Test
public void testMatchNumListIndexes() {
    final Frame frame = makeTestFrame();
    Frame result = null;
    try {
        // NOTE(review): idx(...) presumably maps the labels to their numeric codes in the
        // column's domain; only the tail of that helper is visible here, so confirm.
        final String lookupValues = idx(frame.vec(3), "cB", "cC", "cD");
        final String expression = "(tmp= tst (match (cols data [3]) [" + lookupValues + "] -1 ignored true))";
        result = Rapids.exec(expression).getFrame();
        assertVecEquals(frame.vec(1), result.vec(0), 0.0);
    } finally {
        // Always release the input frame; the result frame only exists if exec succeeded.
        frame.delete();
        if (result != null) {
            result.delete();
        }
    }
}

/**
 * Runs {@code match} with {@code indexes = true} on the categorical column (column 3)
 * against the string list ["cB", "cC", "cD"], with {@code -1} as the no-match value.
 *
 * The result is compared against column 1 of the test frame, which {@code makeTestFrame()}
 * fills with 0 / 1 / 2 for B / C / D and -1 for every other letter.
 */
@Test
public void testMatchCatListIndexes() {
    final Frame frame = makeTestFrame();
    Frame result = null;
    try {
        final String expression = "(tmp= tst (match (cols data [3]) [\"cB\",\"cC\",\"cD\"] -1 ignored true))";
        result = Rapids.exec(expression).getFrame();
        assertVecEquals(frame.vec(1), result.vec(0), 0.0);
    } finally {
        // Always release the input frame; the result frame only exists if exec succeeded.
        frame.delete();
        if (result != null) {
            result.delete();
        }
    }
}

/**
 * Runs {@code match} with {@code indexes = true} on the string column (column 2)
 * against the string list ["sB", "sC", "sD"], with {@code -1} as the no-match value.
 *
 * The result is compared against column 1 of the test frame, which {@code makeTestFrame()}
 * fills with 0 / 1 / 2 for B / C / D and -1 for every other letter.
 */
@Test
public void testMatchStrListIndexes() {
    final Frame frame = makeTestFrame();
    Frame result = null;
    try {
        final String expression = "(tmp= tst (match (cols data [2]) [\"sB\",\"sC\",\"sD\"] -1 ignored true))";
        result = Rapids.exec(expression).getFrame();
        assertVecEquals(frame.vec(1), result.vec(0), 0.0);
    } finally {
        // Always release the input frame; the result frame only exists if exec succeeded.
        frame.delete();
        if (result != null) {
            result.delete();
        }
    }
}

private Frame makeTestFrame() {
Random rnd = new Random();
final int len = 45000;
final int len = 55000;
double numData[] = new double[len];
double[] numDataIndexes = new double[len];
String[] strData = new String[len];
String[] catData = new String[len];
for (int i = 0; i < len; i++) {
char c = (char) ('A' + rnd.nextInt('Z' - 'A'));
numData[i] = c >= 'B' && c <= 'D' ? 1 : -1;
strData[i] = "s" + Character.toString(c);
catData[i] = "c" + Character.toString(c);
numDataIndexes[i] = c == 'B' ? 0 : c == 'C' ? 1 : c == 'D' ? 2 : -1;
strData[i] = "s" + c;
catData[i] = "c" + c;
}
return new TestFrameBuilder()
.withName("data")
.withColNames("Expected", "Str", "Cat")
.withVecTypes(Vec.T_NUM, Vec.T_STR, Vec.T_CAT)
.withColNames("Expected", "Expected_idexes", "Str", "Cat")
.withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_STR, Vec.T_CAT)
.withDataForCol(0, numData)
.withDataForCol(1, strData)
.withDataForCol(2, catData)
.withChunkLayout(10000, 10000, 20000, 5000)
.withDataForCol(1, numDataIndexes)
.withDataForCol(2, strData)
.withDataForCol(3, catData)
.withChunkLayout(10000, 10000, 10000, 20000, 5000)
.build();
}

Expand All @@ -106,4 +161,4 @@ private String idx(Vec v, String... cats) {
return sb.toString();
}

}
}
6 changes: 4 additions & 2 deletions h2o-py/h2o/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4643,14 +4643,16 @@ def stratified_split(self, test_frac=0.2, seed=-1):
"""
return H2OFrame._expr(expr=ExprNode('h2o.random_stratified_split', self, test_frac, seed))

def match(self, table, nomatch=0):
def match(self, table, nomatch=0, indexes=False):
"""
Make a vector of the positions of (first) matches of its first argument in its second.
Only applicable to single-column categorical/string frames.
:param List table: the list of items to match against
:param int nomatch: value that should be returned when there is no match.
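:param bool indexes: if True, return for each matched value its zero-based position in ``table``
(the table is sorted before matching, so the position refers to the sorted order); if False, return 1
for each matched value. In both cases ``nomatch`` is returned for values not found in ``table``.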
:returns: a new H2OFrame containing, for each cell of the source frame, the result of looking up
that cell's value in ``table``, or ``nomatch`` when the value is not present.
Expand All @@ -4662,7 +4664,7 @@ def match(self, table, nomatch=0):
>>> matchFrame = iris["C5"].match(['Iris-setosa'])
>>> matchFrame
"""
return H2OFrame._expr(expr=ExprNode("match", self, table, nomatch, None))
return H2OFrame._expr(expr=ExprNode("match", self, table, nomatch, None, indexes))

def cut(self, breaks, labels=None, include_lowest=False, right=True, dig_lab=3):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from tests import pyunit_utils
from h2o.utils.typechecks import assert_is_type
from h2o.frame import H2OFrame
from random import randrange
import numpy as np


def h2o_H2OFrame_match():
Expand All @@ -15,17 +13,23 @@ def h2o_H2OFrame_match():
Copied from runit_lstrip.R
"""
iris = h2o.import_file(path=pyunit_utils.locate("smalldata/iris/iris.csv"))
matchFrame = iris["C5"].match(['Iris-setosa'])
assert_is_type(matchFrame, H2OFrame) # check return type
assert matchFrame.sum()[0,0]==50.0, "h2o.H2OFrame.match() command is not working." # check return result

match_frame = iris["C5"].match(['Iris-setosa'])
assert_is_type(match_frame, H2OFrame) # check return type
assert match_frame.sum()[0, 0] == 50.0, "h2o.H2OFrame.match() command is not working." # check return result

matchFrame = iris["C5"].match(['Iris-setosa', 'Iris-versicolor'])
assert_is_type(matchFrame, H2OFrame) # check return type
assert matchFrame.sum()[0,0]==100.0, "h2o.H2OFrame.match() command is not working." # check return result
match_frame = iris["C5"].match(['Iris-setosa', 'Iris-versicolor'])
assert_is_type(match_frame, H2OFrame) # check return type
assert match_frame.sum()[0, 0] == 100.0, "h2o.H2OFrame.match() command is not working." # check return result

matchFrame = iris["C5"].match(['Iris-setosa', 'Iris-versicolor', 'Iris-virginica'])
assert_is_type(matchFrame, H2OFrame) # check return type
assert matchFrame.sum()[0,0]==150.0, "h2o.H2OFrame.match() command is not working." # check return result
match_frame = iris["C5"].match(['Iris-setosa', 'Iris-versicolor', 'Iris-virginica'])
assert_is_type(match_frame, H2OFrame) # check return type
assert match_frame.sum()[0, 0] == 150.0, "h2o.H2OFrame.match() command is not working." # check return result

match_frame = iris["C5"].match(['Iris-setosa'])
assert_is_type(match_frame, H2OFrame) # check return type
assert match_frame.sum()[0, 0] == 50.0, "h2o.H2OFrame.match() command is not working." # check return result



pyunit_utils.standalone_test(h2o_H2OFrame_match)

0 comments on commit 71115a0

Please sign in to comment.