Skip to content

Commit

Permalink
Merge branch 'feat-optimizer_action' of https://github.com/Yihao-Xu/I…
Browse files Browse the repository at this point in the history
…GinX into feat-optimizer_action
  • Loading branch information
Yihao-Xu committed Nov 21, 2024
2 parents abd4867 + 5bcd016 commit db6bdf1
Show file tree
Hide file tree
Showing 19 changed files with 643 additions and 72 deletions.
2 changes: 1 addition & 1 deletion antlr/src/main/antlr4/cn/edu/tsinghua/iginx/sql/Sql.g4
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ fromClause
joinPart
: COMMA tableReference
| CROSS JOIN tableReference
| join tableReference (ON orExpression | USING colList)?
| join tableReference (ON orExpression | USING (KEY | colList))?
;

tableReference
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@ private Operator initFilterAndMergeFragmentsWithJoin(UnarySelectStatement select
Filter filter = joinCondition.getFilter();
List<String> joinColumns = joinCondition.getJoinColumns();
boolean isNaturalJoin = JoinType.isNaturalJoin(joinCondition.getJoinType());
boolean isJoinByKey = joinCondition.isJoinByKey();

if (!joinColumns.isEmpty() || isNaturalJoin) {
if (prefixA == null || prefixB == null) {
Expand All @@ -507,6 +508,7 @@ private Operator initFilterAndMergeFragmentsWithJoin(UnarySelectStatement select
filter,
joinColumns,
isNaturalJoin,
isJoinByKey,
joinAlgType);
break;
case LeftOuterJoin:
Expand All @@ -527,6 +529,7 @@ private Operator initFilterAndMergeFragmentsWithJoin(UnarySelectStatement select
filter,
joinColumns,
isNaturalJoin,
isJoinByKey,
joinAlgType);
break;
default:
Expand Down Expand Up @@ -727,21 +730,22 @@ private Operator buildAddSequence(UnarySelectStatement selectStatement, Operator
* @return 添加了Reorder操作符的根节点
*/
private static Operator buildReorder(UnarySelectStatement selectStatement, Operator root) {
boolean hasFuncWithArgs =
boolean hasUDFWithArgs =
selectStatement.getExpressions().stream()
.anyMatch(
expression -> {
if (!(expression instanceof FuncExpression)) {
return false;
}
FuncExpression funcExpression = ((FuncExpression) expression);
return !funcExpression.getArgs().isEmpty()
|| !funcExpression.getKvargs().isEmpty();
return funcExpression.isPyUDF()
&& (!funcExpression.getArgs().isEmpty()
|| !funcExpression.getKvargs().isEmpty());
});

if (selectStatement.isLastFirst()) {
root = new Reorder(new OperatorSource(root), Arrays.asList("path", "value"));
} else if (hasFuncWithArgs) {
} else if (hasUDFWithArgs) {
root = new Reorder(new OperatorSource(root), new ArrayList<>(Collections.singletonList("*")));
} else {
List<String> order = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ public static Operator translateApply(Operator root, List<String> correlatedVari
null,
new ArrayList<>(),
false,
false,
JoinAlgType.HashJoin,
correlatedVariables);
} else {
Expand Down Expand Up @@ -311,6 +312,7 @@ private static Operator pushDownApply(Operator root, List<String> correlatedVari
singleJoin.getFilter(),
new ArrayList<>(),
false,
false,
singleJoin.getJoinAlgType(),
singleJoin.getExtraJoinPrefix());
}
Expand Down Expand Up @@ -374,6 +376,7 @@ private static Operator pushDownApply(Operator root, List<String> correlatedVari
new BoolFilter(true),
new ArrayList<>(),
false,
false,
JoinAlgType.HashJoin,
correlatedVariables);
}
Expand Down Expand Up @@ -563,6 +566,7 @@ private static Operator combineAdjacentSelectAndJoin(Select select) {
select.getFilter(),
new ArrayList<>(),
false,
false,
algType,
crossJoin.getExtraJoinPrefix());
case InnerJoin:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import cn.edu.tsinghua.iginx.engine.shared.data.read.Header;
import cn.edu.tsinghua.iginx.engine.shared.data.read.Row;
import cn.edu.tsinghua.iginx.engine.shared.data.read.RowStream;
import cn.edu.tsinghua.iginx.engine.shared.expr.KeyExpression;
import cn.edu.tsinghua.iginx.engine.shared.function.*;
import cn.edu.tsinghua.iginx.engine.shared.function.system.Max;
import cn.edu.tsinghua.iginx.engine.shared.function.system.Min;
Expand Down Expand Up @@ -660,57 +661,7 @@ private RowStream executeJoin(Join join, Table tableA, Table tableB) throws Phys
}
// 目前只支持使用时间戳和顺序
if (join.getJoinBy().equals(Constants.KEY)) {
// 检查时间戳
if (!headerA.hasKey() || !headerB.hasKey()) {
throw new InvalidOperatorParameterException(
"row streams for join operator by key should have key.");
}
List<Field> newFields = new ArrayList<>();
newFields.addAll(headerA.getFields());
newFields.addAll(headerB.getFields());
Header newHeader = new Header(Field.KEY, newFields);
List<Row> newRows = new ArrayList<>();

int index1 = 0, index2 = 0;
while (index1 < tableA.getRowSize() && index2 < tableB.getRowSize()) {
Row rowA = tableA.getRow(index1), rowB = tableB.getRow(index2);
Object[] values = new Object[newHeader.getFieldSize()];
long timestamp;
if (rowA.getKey() == rowB.getKey()) {
timestamp = rowA.getKey();
System.arraycopy(rowA.getValues(), 0, values, 0, headerA.getFieldSize());
System.arraycopy(
rowB.getValues(), 0, values, headerA.getFieldSize(), headerB.getFieldSize());
index1++;
index2++;
} else if (rowA.getKey() < rowB.getKey()) {
timestamp = rowA.getKey();
System.arraycopy(rowA.getValues(), 0, values, 0, headerA.getFieldSize());
index1++;
} else {
timestamp = rowB.getKey();
System.arraycopy(
rowB.getValues(), 0, values, headerA.getFieldSize(), headerB.getFieldSize());
index2++;
}
newRows.add(new Row(newHeader, timestamp, values));
}

for (; index1 < tableA.getRowSize(); index1++) {
Row rowA = tableA.getRow(index1);
Object[] values = new Object[newHeader.getFieldSize()];
System.arraycopy(rowA.getValues(), 0, values, 0, headerA.getFieldSize());
newRows.add(new Row(newHeader, rowA.getKey(), values));
}

for (; index2 < tableB.getRowSize(); index2++) {
Row rowB = tableB.getRow(index2);
Object[] values = new Object[newHeader.getFieldSize()];
System.arraycopy(
rowB.getValues(), 0, values, headerA.getFieldSize(), headerB.getFieldSize());
newRows.add(new Row(newHeader, rowB.getKey(), values));
}
return new Table(newHeader, newRows);
return executeJoinByKey(tableA, tableB, true, true);
} else if (join.getJoinBy().equals(Constants.ORDINAL)) {
if (headerA.hasKey() || headerB.hasKey()) {
throw new InvalidOperatorParameterException(
Expand Down Expand Up @@ -778,6 +729,17 @@ private RowStream executeCrossJoin(CrossJoin crossJoin, Table tableA, Table tabl

private RowStream executeInnerJoin(InnerJoin innerJoin, Table tableA, Table tableB)
throws PhysicalException {
if (innerJoin.isJoinByKey()) {
Sort sortByKey =
new Sort(
EmptySource.EMPTY_SOURCE,
Collections.singletonList(new KeyExpression(KEY)),
Collections.singletonList(Sort.SortType.ASC));
tableA = transformToTable(executeSort(sortByKey, tableA));
tableB = transformToTable(executeSort(sortByKey, tableB));
return executeJoinByKey(tableA, tableB, false, false);
}

switch (innerJoin.getJoinAlgType()) {
case NestedLoopJoin:
return executeNestedLoopInnerJoin(innerJoin, tableA, tableB);
Expand All @@ -790,6 +752,76 @@ private RowStream executeInnerJoin(InnerJoin innerJoin, Table tableA, Table tabl
}
}

/**
 * Joins two tables on their key column with a single forward merge pass.
 *
 * <p>Rows whose keys are equal are concatenated into one output row (values of {@code tableA}
 * first, then values of {@code tableB}). A row whose key has no partner on the other side is
 * either dropped or emitted with the other side's columns left {@code null}, depending on the
 * two flags. The merge only pairs rows correctly when both tables are ordered by ascending key;
 * that ordering is not verified here.
 *
 * @param tableA left input; its header must have a key
 * @param tableB right input; its header must have a key
 * @param isLeft keep rows of {@code tableA} that have no matching key in {@code tableB}
 * @param isRight keep rows of {@code tableB} that have no matching key in {@code tableA}
 * @return a table with a key column whose fields are those of {@code tableA} followed by those
 *     of {@code tableB}
 * @throws PhysicalException an {@code InvalidOperatorParameterException} if either input header
 *     has no key
 */
private RowStream executeJoinByKey(Table tableA, Table tableB, boolean isLeft, boolean isRight)
    throws PhysicalException {
  Header headerA = tableA.getHeader();
  Header headerB = tableB.getHeader();
  // Both sides must carry a key, otherwise there is nothing to join on.
  if (!(headerA.hasKey() && headerB.hasKey())) {
    throw new InvalidOperatorParameterException(
        "row streams for join operator by key should have key.");
  }

  List<Field> mergedFields = new ArrayList<>(headerA.getFields());
  mergedFields.addAll(headerB.getFields());
  Header mergedHeader = new Header(Field.KEY, mergedFields);

  int widthA = headerA.getFieldSize();
  int widthB = headerB.getFieldSize();
  int totalWidth = mergedHeader.getFieldSize();
  int sizeA = tableA.getRowSize();
  int sizeB = tableB.getRowSize();

  List<Row> joinedRows = new ArrayList<>();
  int posA = 0;
  int posB = 0;

  // Advance through both tables together, always consuming the side with the smaller key.
  while (posA < sizeA && posB < sizeB) {
    Row left = tableA.getRow(posA);
    Row right = tableB.getRow(posB);
    long keyA = left.getKey();
    long keyB = right.getKey();

    if (keyA == keyB) {
      // Matching keys: emit one combined row and advance both sides.
      Object[] combined = new Object[totalWidth];
      System.arraycopy(left.getValues(), 0, combined, 0, widthA);
      System.arraycopy(right.getValues(), 0, combined, widthA, widthB);
      joinedRows.add(new Row(mergedHeader, keyA, combined));
      posA++;
      posB++;
    } else if (keyA < keyB) {
      posA++;
      // Inner and right joins do not keep an unmatched left row.
      if (isLeft) {
        Object[] combined = new Object[totalWidth];
        System.arraycopy(left.getValues(), 0, combined, 0, widthA);
        joinedRows.add(new Row(mergedHeader, keyA, combined));
      }
    } else {
      posB++;
      // Inner and left joins do not keep an unmatched right row.
      if (isRight) {
        Object[] combined = new Object[totalWidth];
        System.arraycopy(right.getValues(), 0, combined, widthA, widthB);
        joinedRows.add(new Row(mergedHeader, keyB, combined));
      }
    }
  }

  // Only left and full joins keep the remaining rows of the left table.
  if (isLeft) {
    while (posA < sizeA) {
      Row left = tableA.getRow(posA);
      Object[] combined = new Object[totalWidth];
      System.arraycopy(left.getValues(), 0, combined, 0, widthA);
      joinedRows.add(new Row(mergedHeader, left.getKey(), combined));
      posA++;
    }
  }

  // Only right and full joins keep the remaining rows of the right table.
  if (isRight) {
    while (posB < sizeB) {
      Row right = tableB.getRow(posB);
      Object[] combined = new Object[totalWidth];
      System.arraycopy(right.getValues(), 0, combined, widthA, widthB);
      joinedRows.add(new Row(mergedHeader, right.getKey(), combined));
      posB++;
    }
  }

  return new Table(mergedHeader, joinedRows);
}

private RowStream executeNestedLoopInnerJoin(InnerJoin innerJoin, Table tableA, Table tableB)
throws PhysicalException {
List<String> joinColumns = new ArrayList<>(innerJoin.getJoinColumns());
Expand Down Expand Up @@ -1187,6 +1219,19 @@ private RowStream executeSortedMergeInnerJoin(InnerJoin innerJoin, Table tableA,

private RowStream executeOuterJoin(OuterJoin outerJoin, Table tableA, Table tableB)
throws PhysicalException {
if (outerJoin.isJoinByKey()) {
Sort sortByKey =
new Sort(
EmptySource.EMPTY_SOURCE,
Collections.singletonList(new KeyExpression(KEY)),
Collections.singletonList(Sort.SortType.ASC));
tableA = transformToTable(executeSort(sortByKey, tableA));
tableB = transformToTable(executeSort(sortByKey, tableB));
boolean isLeft = outerJoin.getOuterJoinType() != OuterJoinType.RIGHT;
boolean isRight = outerJoin.getOuterJoinType() != OuterJoinType.LEFT;
return executeJoinByKey(tableA, tableB, isLeft, isRight);
}

switch (outerJoin.getJoinAlgType()) {
case NestedLoopJoin:
return executeNestedLoopOuterJoin(outerJoin, tableA, tableB);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public class FunctionUtils {
private static final String VALUE = "value";

private static final Set<String> sysRowToRowFunctionSet =
new HashSet<>(Arrays.asList("ratio", "substring"));
new HashSet<>(Arrays.asList("extract", "ratio", "substring"));

private static final Set<String> sysSetToRowFunctionSet =
new HashSet<>(
Expand Down Expand Up @@ -165,7 +165,6 @@ public static String getFunctionName(Function function) {

static Map<String, Integer> expectedParamNumMap = new HashMap<>(); // 此Map用于存储function期望的参数个数

// TODO
static {
expectedParamNumMap.put("avg", 1);
expectedParamNumMap.put("sum", 1);
Expand All @@ -176,6 +175,7 @@ public static String getFunctionName(Function function) {
expectedParamNumMap.put("last_value", 1);
expectedParamNumMap.put("first", 1);
expectedParamNumMap.put("last", 1);
expectedParamNumMap.put("extract", 1);
expectedParamNumMap.put("ratio", 2);
expectedParamNumMap.put("substring", 1);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import cn.edu.tsinghua.iginx.engine.shared.function.system.ArithmeticExpr;
import cn.edu.tsinghua.iginx.engine.shared.function.system.Avg;
import cn.edu.tsinghua.iginx.engine.shared.function.system.Count;
import cn.edu.tsinghua.iginx.engine.shared.function.system.Extract;
import cn.edu.tsinghua.iginx.engine.shared.function.system.First;
import cn.edu.tsinghua.iginx.engine.shared.function.system.FirstValue;
import cn.edu.tsinghua.iginx.engine.shared.function.system.Last;
Expand Down Expand Up @@ -100,6 +101,7 @@ private void initSystemFunctions() {
registerFunction(Min.getInstance());
registerFunction(Sum.getInstance());
registerFunction(ArithmeticExpr.getInstance());
registerFunction(Extract.getInstance());
registerFunction(Ratio.getInstance());
registerFunction(SubString.getInstance());
}
Expand Down
Loading

0 comments on commit db6bdf1

Please sign in to comment.