Skip to content

Commit

Permalink
feat(sql): GROUP BY expr && ORDER BY expr (#465)
Browse files Browse the repository at this point in the history
* feat(sql): GROUP BY expr && ORDER BY expr

1.支持对GROUP BY和ORDER BY中的列使用RowToRow表达式
2.支持GROUP BY和ORDER BY中的列与SELECT子句中的别名进行匹配


Co-authored-by: Yuqing Zhu <[email protected]>
  • Loading branch information
jzl18thu and zhuyuqing authored Oct 22, 2024
1 parent 67e6616 commit aff9e62
Show file tree
Hide file tree
Showing 35 changed files with 949 additions and 203 deletions.
13 changes: 9 additions & 4 deletions antlr/src/main/antlr4/cn/edu/tsinghua/iginx/sql/Sql.g4
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,11 @@ functionName
;

caseSpecification
: simipleCase
: simpleCase
| searchedCase
;

simipleCase
simpleCase
: CASE expression simpleWhenClause (simpleWhenClause)* elseClause? END
;

Expand Down Expand Up @@ -303,7 +303,12 @@ specialClause
;

groupByClause
: GROUP BY path (COMMA path)*
: GROUP BY groupByItem (COMMA groupByItem)*
;

groupByItem
: path
| expression
;

havingClause
Expand All @@ -315,7 +320,7 @@ orderByClause
;

orderItem
: path (DESC | ASC)?
: (path | expression) (DESC | ASC)?
;

downsampleClause
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -583,15 +583,15 @@ private static Operator buildLimit(SelectStatement selectStatement, Operator roo
* @return 添加了Sort操作符的根节点;如果没有Order By子句,返回原根节点
*/
private static Operator buildOrderByPaths(SelectStatement selectStatement, Operator root) {
if (selectStatement.getOrderByPaths().isEmpty()) {
if (selectStatement.getOrderByExpressions().isEmpty()) {
return root;
}
List<Sort.SortType> sortTypes = new ArrayList<>();
selectStatement
.getAscendingList()
.forEach(
isAscending -> sortTypes.add(isAscending ? Sort.SortType.ASC : Sort.SortType.DESC));
return new Sort(new OperatorSource(root), selectStatement.getOrderByPaths(), sortTypes);
return new Sort(new OperatorSource(root), selectStatement.getOrderByExpressions(), sortTypes);
}

/**
Expand Down Expand Up @@ -662,7 +662,7 @@ private Operator buildGroupByQuery(UnarySelectStatement selectStatement, Operato
List<FunctionCall> functionCallList =
getFunctionCallList(selectStatement, MappingType.SetMapping);
return new GroupBy(
new OperatorSource(root), selectStatement.getGroupByPaths(), functionCallList);
new OperatorSource(root), selectStatement.getGroupByExpressions(), functionCallList);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import static cn.edu.tsinghua.iginx.engine.shared.operator.type.OperatorType.isUnaryOperator;

import cn.edu.tsinghua.iginx.engine.shared.expr.BaseExpression;
import cn.edu.tsinghua.iginx.engine.shared.expr.Expression;
import cn.edu.tsinghua.iginx.engine.shared.function.FunctionCall;
import cn.edu.tsinghua.iginx.engine.shared.function.FunctionParams;
import cn.edu.tsinghua.iginx.engine.shared.function.FunctionUtils;
Expand Down Expand Up @@ -316,18 +317,19 @@ private static Operator pushDownApply(Operator root, List<String> correlatedVari
root =
new GroupBy(
new OperatorSource(pushDownApply(apply, correlatedVariables)),
correlatedVariables,
correlatedVariables.stream().map(BaseExpression::new).collect(Collectors.toList()),
setTransform.getFunctionCallList());
break;
case GroupBy:
GroupBy groupBy = (GroupBy) operatorB;
apply.setSourceB(groupBy.getSource());
List<String> groupByCols = groupBy.getGroupByCols();
groupByCols.addAll(correlatedVariables);
List<Expression> groupByExpressions = groupBy.getGroupByExpressions();
groupByExpressions.addAll(
correlatedVariables.stream().map(BaseExpression::new).collect(Collectors.toList()));
root =
new GroupBy(
new OperatorSource(pushDownApply(apply, correlatedVariables)),
groupByCols,
groupByExpressions,
groupBy.getFunctionCallList());
break;
case Rename:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,11 @@ private RowStream executeSelect(Select select, Table table) throws PhysicalExcep
}

private RowStream executeSort(Sort sort, Table table) throws PhysicalException {
RowTransform preRowTransform = HeaderUtils.checkSortHeader(table.getHeader(), sort);
if (preRowTransform != null) {
table = transformToTable(executeRowTransform(preRowTransform, table));
}

List<Boolean> ascendingList = sort.getAscendingList();
RowUtils.sortRows(table.getRows(), ascendingList, sort.getSortByCols());
return table;
Expand Down Expand Up @@ -483,6 +488,11 @@ private RowStream executeAddSchemaPrefix(AddSchemaPrefix addSchemaPrefix, Table
}

private RowStream executeGroupBy(GroupBy groupBy, Table table) throws PhysicalException {
RowTransform preRowTransform = HeaderUtils.checkGroupByHeader(table.getHeader(), groupBy);
if (preRowTransform != null) {
table = transformToTable(executeRowTransform(preRowTransform, table));
}

List<Row> rows = RowUtils.cacheGroupByResult(groupBy, table);
if (rows.isEmpty()) {
return Table.EMPTY_TABLE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import cn.edu.tsinghua.iginx.engine.physical.exception.UnexpectedOperatorException;
import cn.edu.tsinghua.iginx.engine.physical.memory.execute.OperatorMemoryExecutor;
import cn.edu.tsinghua.iginx.engine.physical.memory.execute.Table;
import cn.edu.tsinghua.iginx.engine.physical.memory.execute.utils.HeaderUtils;
import cn.edu.tsinghua.iginx.engine.physical.memory.execute.utils.RowUtils;
import cn.edu.tsinghua.iginx.engine.shared.Constants;
import cn.edu.tsinghua.iginx.engine.shared.RequestContext;
Expand Down Expand Up @@ -198,6 +199,11 @@ private RowStream executeSelect(Select select, RowStream stream) {
}

private RowStream executeSort(Sort sort, RowStream stream) throws PhysicalException {
RowTransform preRowTransform = HeaderUtils.checkSortHeader(stream.getHeader(), sort);
if (preRowTransform != null) {
stream = executeRowTransform(preRowTransform, stream);
}

return new SortLazyStream(sort, stream);
}

Expand Down Expand Up @@ -270,7 +276,12 @@ private RowStream executeAddSchemaPrefix(AddSchemaPrefix addSchemaPrefix, RowStr
return new AddSchemaPrefixLazyStream(addSchemaPrefix, stream);
}

private RowStream executeGroupBy(GroupBy groupBy, RowStream stream) {
private RowStream executeGroupBy(GroupBy groupBy, RowStream stream) throws PhysicalException {
RowTransform preRowTransform = HeaderUtils.checkGroupByHeader(stream.getHeader(), groupBy);
if (preRowTransform != null) {
stream = executeRowTransform(preRowTransform, stream);
}

return new GroupByLazyStream(groupBy, stream);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
*/
package cn.edu.tsinghua.iginx.engine.physical.memory.execute.utils;

import static cn.edu.tsinghua.iginx.engine.shared.function.system.ArithmeticExpr.ARITHMETIC_EXPR;
import static cn.edu.tsinghua.iginx.engine.shared.function.system.utils.ValueUtils.isNumericType;
import static cn.edu.tsinghua.iginx.sql.SQLConstant.DOT;
import static cn.edu.tsinghua.iginx.thrift.DataType.BOOLEAN;
Expand All @@ -28,13 +29,26 @@
import cn.edu.tsinghua.iginx.engine.physical.exception.PhysicalException;
import cn.edu.tsinghua.iginx.engine.shared.data.read.Field;
import cn.edu.tsinghua.iginx.engine.shared.data.read.Header;
import cn.edu.tsinghua.iginx.engine.shared.expr.BaseExpression;
import cn.edu.tsinghua.iginx.engine.shared.expr.Expression;
import cn.edu.tsinghua.iginx.engine.shared.expr.KeyExpression;
import cn.edu.tsinghua.iginx.engine.shared.function.Function;
import cn.edu.tsinghua.iginx.engine.shared.function.FunctionCall;
import cn.edu.tsinghua.iginx.engine.shared.function.FunctionParams;
import cn.edu.tsinghua.iginx.engine.shared.function.manager.FunctionManager;
import cn.edu.tsinghua.iginx.engine.shared.operator.GroupBy;
import cn.edu.tsinghua.iginx.engine.shared.operator.RowTransform;
import cn.edu.tsinghua.iginx.engine.shared.operator.Sort;
import cn.edu.tsinghua.iginx.engine.shared.operator.filter.*;
import cn.edu.tsinghua.iginx.engine.shared.source.EmptySource;
import cn.edu.tsinghua.iginx.thrift.DataType;
import cn.edu.tsinghua.iginx.utils.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;

public class HeaderUtils {

Expand Down Expand Up @@ -330,4 +344,55 @@ public static void checkHeadersComparable(Header headerA, Header headerB)
}
}
}

public static RowTransform checkGroupByHeader(Header header, GroupBy groupBy) {
Set<Expression> appendExpressions = new HashSet<>();
for (Expression groupByExpr : groupBy.getGroupByExpressions()) {
String exprName = groupByExpr.getColumnName();
boolean found =
header.getFields().stream().anyMatch(field -> field.getName().equals(exprName));
if (!found) {
appendExpressions.add(groupByExpr);
}
}

if (appendExpressions.isEmpty()) {
return null;
}
return appendArithExpressions(header, new ArrayList<>(appendExpressions));
}

public static RowTransform checkSortHeader(Header header, Sort sort) {
List<Expression> sortExpressions = new ArrayList<>(sort.getSortByExpressions());
if (sortExpressions.get(0) instanceof KeyExpression) {
sortExpressions.remove(0);
}
Set<Expression> appendExpressions = new HashSet<>();
for (Expression sortExpr : sortExpressions) {
String exprName = sortExpr.getColumnName();
boolean found =
header.getFields().stream().anyMatch(field -> field.getName().equals(exprName));
if (!found) {
appendExpressions.add(sortExpr);
}
}

if (appendExpressions.isEmpty()) {
return null;
}
return appendArithExpressions(header, new ArrayList<>(appendExpressions));
}

private static RowTransform appendArithExpressions(Header header, List<Expression> expressions) {
List<FunctionCall> functionCallList = new ArrayList<>();
Function function = FunctionManager.getInstance().getFunction(ARITHMETIC_EXPR);
for (Field field : header.getFields()) {
functionCallList.add(
new FunctionCall(function, new FunctionParams(new BaseExpression(field.getName()))));
}
for (Expression expr : expressions) {
functionCallList.add(new FunctionCall(function, new FunctionParams(expr)));
}
return new RowTransform(EmptySource.EMPTY_SOURCE, functionCallList);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,16 @@ public void setAlias(String alias) {
public void accept(ExpressionVisitor visitor) {
visitor.visit(this);
}

@Override
public boolean equalExceptAlias(Expression expr) {
if (this == expr) {
return true;
}
if (expr == null || expr.getType() != ExpressionType.Base) {
return false;
}
BaseExpression that = (BaseExpression) expr;
return this.pathName.equals(that.pathName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,18 @@ public void accept(ExpressionVisitor visitor) {
leftExpression.accept(visitor);
rightExpression.accept(visitor);
}

@Override
public boolean equalExceptAlias(Expression expr) {
if (this == expr) {
return true;
}
if (expr == null || expr.getType() != ExpressionType.Binary) {
return false;
}
BinaryExpression that = (BinaryExpression) expr;
return this.leftExpression.equalExceptAlias(that.leftExpression)
&& this.rightExpression.equalExceptAlias(that.rightExpression)
&& this.op == that.op;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,16 @@ public void accept(ExpressionVisitor visitor) {
visitor.visit(this);
expression.accept(visitor);
}

@Override
public boolean equalExceptAlias(Expression expr) {
if (this == expr) {
return true;
}
if (expr == null || expr.getType() != ExpressionType.Bracket) {
return false;
}
BracketExpression that = (BracketExpression) expr;
return this.expression.equalExceptAlias(that.expression);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,38 @@ public void accept(ExpressionVisitor visitor) {
}
resultElse.accept(visitor);
}

@Override
public boolean equalExceptAlias(Expression expr) {
if (this == expr) {
return true;
}
if (expr == null || expr.getType() != ExpressionType.CaseWhen) {
return false;
}
CaseWhenExpression that = (CaseWhenExpression) expr;
if (this.conditions.size() != that.conditions.size()) {
return false;
}
for (int i = 0; i < this.conditions.size(); i++) {
if (!this.conditions.get(i).equals(that.conditions.get(i))) {
return false;
}
}
if (this.results.size() != that.results.size()) {
return false;
}
for (int i = 0; i < this.results.size(); i++) {
if (!this.results.get(i).equalExceptAlias(that.results.get(i))) {
return false;
}
}
if (this.resultElse == null && that.resultElse == null) {
return true;
}
if (this.resultElse == null || that.resultElse == null) {
return false;
}
return this.resultElse.equalExceptAlias(that.resultElse);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
*/
package cn.edu.tsinghua.iginx.engine.shared.expr;

import java.util.Arrays;

public class ConstantExpression implements Expression {

private final Object value;
Expand Down Expand Up @@ -70,4 +72,20 @@ public void setAlias(String alias) {
public void accept(ExpressionVisitor visitor) {
visitor.visit(this);
}

@Override
public boolean equalExceptAlias(Expression expr) {
if (this == expr) {
return true;
}
if (expr == null || expr.getType() != ExpressionType.Constant) {
return false;
}
ConstantExpression that = (ConstantExpression) expr;
if (value instanceof byte[]) {
return Arrays.equals((byte[]) value, (byte[]) that.value);
} else {
return this.value.equals(that.value);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ public interface Expression {

void accept(ExpressionVisitor visitor);

boolean equalExceptAlias(Expression expr);

enum ExpressionType {
Bracket,
Binary,
Expand Down
Loading

0 comments on commit aff9e62

Please sign in to comment.