Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

901 inner join and semi join with result cardinality hint #918

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/ir/daphneir/DaphneOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1076,7 +1076,7 @@ def Daphne_InnerJoinOp : Daphne_Op<"innerJoin", [
DataTypeFrm, ValueTypesConcat,
DeclareOpInterfaceMethods<InferFrameLabelsOpInterface>,
]> {
let arguments = (ins FrameOrU:$lhs, FrameOrU:$rhs, StrScalar:$lhsOn, StrScalar:$rhsOn);
let arguments = (ins FrameOrU:$lhs, FrameOrU:$rhs, StrScalar:$lhsOn, StrScalar:$rhsOn, Size:$numRowRes);
let results = (outs FrameOrU:$res);
}

Expand Down Expand Up @@ -1120,7 +1120,7 @@ def Daphne_SemiJoinOp : Daphne_Op<"semiJoin", [
DeclareOpInterfaceMethods<InferTypesOpInterface>,
NumColsFromArg
]> {
let arguments = (ins FrameOrU:$lhs, FrameOrU:$rhs, StrScalar:$lhsOn, StrScalar:$rhsOn);
let arguments = (ins FrameOrU:$lhs, FrameOrU:$rhs, StrScalar:$lhsOn, StrScalar:$rhsOn, Size:$numRowRes);
let results = (outs FrameOrU:$res, MatrixOf<[Size]>:$lhsTids);
}

Expand Down
18 changes: 14 additions & 4 deletions src/parser/daphnedsl/DaphneDSLBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -993,13 +993,18 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string &fu
builder.create<CartesianOp>(loc, FrameType::get(builder.getContext(), colTypes), args[0], args[1]));
}
if (func == "innerJoin") {
checkNumArgsExact(loc, func, numArgs, 4);
checkNumArgsMin(loc, func, numArgs, 4);
std::vector<mlir::Type> colTypes;
mlir::Value numRowRes;
for (int i = 0; i < 2; i++)
for (mlir::Type t : args[i].getType().dyn_cast<FrameType>().getColumnTypes())
colTypes.push_back(t);
if (numArgs == 5)
numRowRes = utils.castSI64If(args[4]);
else
numRowRes = builder.create<ConstantOp>(loc, int64_t(-1));
return static_cast<mlir::Value>(builder.create<InnerJoinOp>(loc, FrameType::get(builder.getContext(), colTypes),
args[0], args[1], args[2], args[3]));
args[0], args[1], args[2], args[3], numRowRes));
}
if (func == "fullOuterJoin")
return createJoinOp<FullOuterJoinOp>(loc, func, args);
Expand All @@ -1011,14 +1016,19 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string &fu
// TODO Reconcile this with the other join ops, but we need it to work
// quickly now.
// return createJoinOp<SemiJoinOp>(loc, func, args);
checkNumArgsExact(loc, func, numArgs, 4);
checkNumArgsMin(loc, func, numArgs, 4);
mlir::Value lhs = args[0];
mlir::Value rhs = args[1];
mlir::Value lhsOn = args[2];
mlir::Value rhsOn = args[3];
mlir::Value numRowRes;
if (numArgs == 5)
numRowRes = utils.castSI64If(args[4]);
else
numRowRes = builder.create<ConstantOp>(loc, int64_t(-1));
return builder
.create<SemiJoinOp>(loc, FrameType::get(builder.getContext(), {utils.unknownType}), utils.matrixOfSizeType,
lhs, rhs, lhsOn, rhsOn)
lhs, rhs, lhsOn, rhsOn, numRowRes)
.getResults();
}
if (func == "groupJoin") {
Expand Down
5 changes: 4 additions & 1 deletion src/parser/sql/SQLVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -551,8 +551,11 @@ antlrcpp::Any SQLVisitor::visitInnerJoin(SQLGrammarParser::InnerJoinContext *ctx
mlir::Value rhsName = valueOrErrorOnVisit(ctx->rhs);
mlir::Value lhsName = valueOrErrorOnVisit(ctx->lhs);

mlir::Value numRowRes =
static_cast<mlir::Value>(builder.create<mlir::daphne::ConstantOp>(queryLoc, static_cast<int64_t>(-1)));

return static_cast<mlir::Value>(
builder.create<mlir::daphne::InnerJoinOp>(loc, t, currentFrame, tojoin, rhsName, lhsName));
builder.create<mlir::daphne::InnerJoinOp>(loc, t, currentFrame, tojoin, rhsName, lhsName, numRowRes));
}

std::vector<mlir::Value> rhsNames;
Expand Down
5 changes: 4 additions & 1 deletion src/runtime/local/kernels/CastObj.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ template <typename VTRes> class CastObj<DenseMatrix<VTRes>, Frame> {
const size_t numRows = argFrm->getNumRows();
const DenseMatrix<VTArg> *argCol = argFrm->getColumn<VTArg>(c);
for (size_t r = 0; r < numRows; r++)
res->set(r, c, static_cast<VTRes>(argCol->get(r, 0)));
res->set(r, c, castSca<VTRes, VTArg>(argCol->get(r, 0), nullptr));
DataObjectFactory::destroy(argCol);
}

Expand Down Expand Up @@ -126,6 +126,9 @@ template <typename VTRes> class CastObj<DenseMatrix<VTRes>, Frame> {
case ValueTypeCode::UI8:
castCol<uint8_t>(res, arg, c);
break;
case ValueTypeCode::STR:
castCol<std::string>(res, arg, c);
break;
default:
throw std::runtime_error("CastObj::apply: unknown value type code");
}
Expand Down
7 changes: 7 additions & 0 deletions src/runtime/local/kernels/CastSca.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,11 @@ template <typename VTRes> struct CastSca<VTRes, FixedStr16> {
}
};

// ----------------------------------------------------------------------------
// string <- string
// ----------------------------------------------------------------------------
template <> struct CastSca<std::string, std::string> {
    /**
     * @brief Identity cast from string to string.
     *
     * @param arg The string value to cast.
     * @param ctx The context pointer; it is not used by this specialization.
     * @return A copy of `arg`.
     */
    static std::string apply(const std::string &arg, DaphneContext *ctx) {
        std::string res(arg);
        return res;
    }
};

#endif // SRC_RUNTIME_LOCAL_KERNELS_CASTSCA_H
2 changes: 1 addition & 1 deletion src/runtime/local/kernels/EwBinaryObjSca.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ struct EwBinaryObjSca<DenseMatrix<VTRes>, DenseMatrix<VTLhs>, VTRhs> {
if (res == nullptr)
res = DataObjectFactory::create<DenseMatrix<VTRes>>(numRows, numCols, false);

const VTRes *valuesLhs = lhs->getValues();
const VTLhs *valuesLhs = lhs->getValues();
VTRes *valuesRes = res->getValues();

EwBinaryScaFuncPtr<VTRes, VTLhs, VTRhs> func = getEwBinaryScaFuncPtr<VTRes, VTLhs, VTRhs>(opCode);
Expand Down
13 changes: 12 additions & 1 deletion src/runtime/local/kernels/InnerJoin.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,19 @@ inline void innerJoin(
const Frame *lhs, const Frame *rhs,
// input column names
const char *lhsOn, const char *rhsOn,
// result size
int64_t numRowRes,
// context
DCTX(ctx)) {

// Find out the value types of the columns to process.
ValueTypeCode vtcLhsOn = lhs->getColumnType(lhsOn);
ValueTypeCode vtcRhsOn = rhs->getColumnType(rhsOn);

// Perhaps check if res already allocated.
const size_t numRowRhs = rhs->getNumRows();
const size_t numRowLhs = lhs->getNumRows();
const size_t totalRows = numRowRhs * numRowLhs;
const size_t totalRows = numRowRes == -1 ? numRowRhs * numRowLhs : numRowRes;
const size_t numColRhs = rhs->getNumCols();
const size_t numColLhs = lhs->getNumCols();
const size_t totalCols = numColRhs + numColLhs;
Expand Down Expand Up @@ -126,15 +129,23 @@ inline void innerJoin(
row_idx_r, ctx);
hit = hit || innerJoinProbeIf<double, double>(vtcLhsOn, vtcRhsOn, res, lhs, rhs, lhsOn, rhsOn, row_idx_l,
row_idx_r, ctx);
hit = hit || innerJoinProbeIf<std::string, std::string>(vtcLhsOn, vtcRhsOn, res, lhs, rhs, lhsOn, rhsOn,
row_idx_l, row_idx_r, ctx);
if (hit) {
for (size_t idx_c = 0; idx_c < numColLhs; idx_c++) {
innerJoinSet<std::string>(schema[col_idx_res], res, lhs, row_idx_res, col_idx_res, row_idx_l, idx_c,
ctx);
innerJoinSet<int64_t>(schema[col_idx_res], res, lhs, row_idx_res, col_idx_res, row_idx_l, idx_c,
ctx);
innerJoinSet<double>(schema[col_idx_res], res, lhs, row_idx_res, col_idx_res, row_idx_l, idx_c,
ctx);
col_idx_res++;
}
for (size_t idx_c = 0; idx_c < numColRhs; idx_c++) {

innerJoinSet<std::string>(schema[col_idx_res], res, rhs, row_idx_res, col_idx_res, row_idx_r, idx_c,
ctx);

innerJoinSet<int64_t>(schema[col_idx_res], res, rhs, row_idx_res, col_idx_res, row_idx_r, idx_c,
ctx);

Expand Down
21 changes: 15 additions & 6 deletions src/runtime/local/kernels/SemiJoin.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ void semiJoinCol(
Frame *&res, DenseMatrix<VTTid> *&resLhsTid,
// arguments
const DenseMatrix<VTLhs> *argLhs, const DenseMatrix<VTRhs> *argRhs,
// result size
int64_t numRowRes,
// context
DCTX(ctx)) {
if (argLhs->getNumCols() != 1)
Expand All @@ -72,11 +74,14 @@ void semiJoinCol(
// Create the output data objects.
if (res == nullptr) {
ValueTypeCode schema[] = {ValueTypeUtils::codeFor<VTLhs>};
res = DataObjectFactory::create<Frame>(numArgLhs, 1, schema, nullptr, false);
const size_t resSize = numRowRes == -1 ? numArgLhs : numRowRes;
res = DataObjectFactory::create<Frame>(resSize, 1, schema, nullptr, false);
}
auto resLhs = res->getColumn<VTLhs>(0);
if (resLhsTid == nullptr)
resLhsTid = DataObjectFactory::create<DenseMatrix<VTTid>>(numArgLhs, 1, false);
if (resLhsTid == nullptr) {
const size_t resLhsTidSize = numRowRes == -1 ? numArgLhs : numRowRes;
resLhsTid = DataObjectFactory::create<DenseMatrix<VTTid>>(resLhsTidSize, 1, false);
}

size_t pos = 0;
for (size_t i = 0; i < numArgLhs; i++) {
Expand Down Expand Up @@ -107,11 +112,13 @@ void semiJoinColIf(
const Frame *lhs, const Frame *rhs,
// input column names
const char *lhsOn, const char *rhsOn,
// result size
int64_t numRowRes,
// context
DCTX(ctx)) {
if (vtcLhs == ValueTypeUtils::codeFor<VTLhs> && vtcRhs == ValueTypeUtils::codeFor<VTRhs>) {
semiJoinCol<VTLhs, VTRhs, VTTid>(res, resLhsTid, lhs->getColumn<VTLhs>(lhsOn), rhs->getColumn<VTRhs>(rhsOn),
ctx);
numRowRes, ctx);
}
}

Expand All @@ -127,6 +134,8 @@ void semiJoin(
const Frame *lhs, const Frame *rhs,
// input column names
const char *lhsOn, const char *rhsOn,
// result size
int64_t numRowRes,
// context
DCTX(ctx)) {
// Find out the value types of the columns to process.
Expand All @@ -136,8 +145,8 @@ void semiJoin(
// Call the semiJoin-kernel on columns for the actual combination of
// value types.
// Repeat this for all type combinations...
semiJoinColIf<int64_t, int64_t, VTLhsTid>(vtcLhsOn, vtcRhsOn, res, lhsTid, lhs, rhs, lhsOn, rhsOn, ctx);
semiJoinColIf<int64_t, int64_t, VTLhsTid>(vtcLhsOn, vtcRhsOn, res, lhsTid, lhs, rhs, lhsOn, rhsOn, ctx);
semiJoinColIf<int64_t, int64_t, VTLhsTid>(vtcLhsOn, vtcRhsOn, res, lhsTid, lhs, rhs, lhsOn, rhsOn, numRowRes, ctx);
semiJoinColIf<int64_t, int64_t, VTLhsTid>(vtcLhsOn, vtcRhsOn, res, lhsTid, lhs, rhs, lhsOn, rhsOn, numRowRes, ctx);

// Set the column labels of the result frame.
std::string labels[] = {lhsOn};
Expand Down
10 changes: 9 additions & 1 deletion src/runtime/local/kernels/kernels.json
Original file line number Diff line number Diff line change
Expand Up @@ -2974,7 +2974,11 @@
{
"type": "const char *",
"name": "rhsOn"
}
},
{
"type": "int64_t",
"name": "numRowRes"
}
]
},
"instantiations": [[]]
Expand Down Expand Up @@ -4267,6 +4271,10 @@
{
"type": "const char *",
"name": "rhsOn"
},
{
"type": "int64_t",
"name": "numRowRes"
}
]
},
Expand Down
2 changes: 2 additions & 0 deletions test/api/cli/operations/OperationsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ MAKE_TEST_CASE("fill", 1)
MAKE_TEST_CASE("gemv", 1)
MAKE_TEST_CASE("idxMax", 1)
MAKE_TEST_CASE("idxMin", 1)
MAKE_TEST_CASE("innerJoin", 1)
MAKE_TEST_CASE("isNan", 1)
MAKE_TEST_CASE("lower", 1)
MAKE_TEST_CASE("mean", 1)
Expand All @@ -59,6 +60,7 @@ MAKE_TEST_CASE("rbind", 1)
MAKE_TEST_CASE("recode", 4)
MAKE_TEST_CASE("replace", 1)
MAKE_TEST_CASE("reverse", 1)
MAKE_TEST_CASE("semiJoin", 1)
MAKE_TEST_CASE("seq", 2)
MAKE_TEST_CASE("solve", 1)
MAKE_TEST_CASE("sqrt", 1)
Expand Down
14 changes: 14 additions & 0 deletions test/api/cli/operations/innerJoin_1.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Test inner join with the optional argument for the result size.
# f1 has the columns a = [1, 2] and b = [3, 4].
f1 = createFrame(
[1, 2], [3, 4],
"a", "b"
);
# f2 has the columns c = [3, 4, 5] and d = [6, 7, 8].
f2 = createFrame(
[3, 4, 5], [6, 7, 8],
"c", "d"
);

# Join f1 and f2 on f1.b and f2.c: first without the optional fifth argument,
# then with a result size of 2 passed as the fifth argument.
f3 = innerJoin(f1, f2, "b", "c");
f4 = innerJoin(f1, f2, "b", "c", 2);
# The reference output lists the same 2x4 frame for both calls.
print(f3);
print(f4);
6 changes: 6 additions & 0 deletions test/api/cli/operations/innerJoin_1.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Frame(2x4, [a:int64_t, b:int64_t, c:int64_t, d:int64_t])
1 3 3 6
2 4 4 7
Frame(2x4, [a:int64_t, b:int64_t, c:int64_t, d:int64_t])
1 3 3 6
2 4 4 7
8 changes: 8 additions & 0 deletions test/api/cli/operations/semiJoin_1.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Test semi join with the optional argument for the result size.
# f1 has the columns a = [1, 2] and b = [3, 4];
# f2 has the columns c = [3, 4, 5] and d = [6, 7, 8].
f1 = createFrame([ 1, 2 ], [ 3, 4 ], "a", "b");
f2 = createFrame([ 3, 4, 5 ], [ 6, 7, 8 ], "c", "d");

# Semi join of f1 and f2 on f1.b and f2.c: first without the optional fifth
# argument, then with a result size of 2 passed as the fifth argument.
keys1, tids1 = semiJoin(f1, f2, "b", "c");
keys2, tids2 = semiJoin(f1, f2, "b", "c", 2);
# Use the returned tids to select rows of f1; the reference output lists the
# same 2x2 frame for both calls.
print(f1[tids1, ]);
print(f1[tids2, ]);
6 changes: 6 additions & 0 deletions test/api/cli/operations/semiJoin_1.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Frame(2x2, [a:int64_t, b:int64_t])
1 3
2 4
Frame(2x2, [a:int64_t, b:int64_t])
1 3
2 4
4 changes: 2 additions & 2 deletions test/runtime/local/kernels/InnerJoinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

#include <cstdint>

TEST_CASE("innerJoin", TAG_KERNELS) {
TEST_CASE("InnerJoin", TAG_KERNELS) {
auto lhsC0 = genGivenVals<DenseMatrix<int64_t>>(4, {1, 2, 3, 4});
auto lhsC1 = genGivenVals<DenseMatrix<double>>(4, {11.0, 22.0, 33.0, 44.00});
std::vector<Structure *> lhsCols = {lhsC0, lhsC1};
Expand All @@ -46,7 +46,7 @@ TEST_CASE("innerJoin", TAG_KERNELS) {
auto rhs = DataObjectFactory::create<Frame>(rhsCols, rhsLabels);

Frame *res = nullptr;
innerJoin(res, lhs, rhs, "a", "c", nullptr);
innerJoin(res, lhs, rhs, "a", "c", -1, nullptr);

// Check the meta data.
CHECK(res->getNumRows() == 2);
Expand Down
2 changes: 1 addition & 1 deletion test/runtime/local/kernels/SemiJoinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ TEST_CASE("SemiJoin", TAG_KERNELS) {
// res
Frame *res = nullptr;
DenseMatrix<int64_t> *lhsTid = nullptr;
semiJoin(res, lhsTid, lhs, rhs, "a", "c", nullptr);
semiJoin(res, lhsTid, lhs, rhs, "a", "c", -1, nullptr);

CHECK(*res == *expRes);
CHECK(*lhsTid == *expTid);
Expand Down