From 82f5193953103f31b60da14983a3a2944cfb3cea Mon Sep 17 00:00:00 2001 From: YANGDB Date: Tue, 29 Aug 2023 17:11:30 -0700 Subject: [PATCH 1/7] add ppl to catalyst logical plan transformer Signed-off-by: YANGDB --- spark/build.gradle | 27 ++ .../sql/spark/ppl/CatalystPlanContext.java | 77 ++++ .../spark/ppl/CatalystQueryPlanVisitor.java | 356 ++++++++++++++++++ .../sql/spark/request/SparkQueryRequest.java | 2 + .../sql/spark/client/EmrClientImplTest.java | 12 +- .../sql/spark/constants/TestConstants.java | 3 +- .../SparkSqlFunctionImplementationTest.java | 10 +- .../SparkSqlFunctionTableScanBuilderTest.java | 6 +- ...SparkSqlFunctionTableScanOperatorTest.java | 16 +- .../SparkSqlTableFunctionResolverTest.java | 12 +- .../spark/ppl/SparkPPLLogicalBuilderTest.java | 119 ++++++ .../sql/spark/storage/SparkScanTest.java | 6 +- .../sql/spark/storage/SparkTableTest.java | 6 +- 13 files changed, 617 insertions(+), 35 deletions(-) create mode 100644 spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystPlanContext.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystQueryPlanVisitor.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/ppl/SparkPPLLogicalBuilderTest.java diff --git a/spark/build.gradle b/spark/build.gradle index 89842e5ea8..e06748f6a7 100644 --- a/spark/build.gradle +++ b/spark/build.gradle @@ -16,10 +16,37 @@ repositories { dependencies { api project(':core') implementation project(':datasources') + implementation project(':ppl') implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" implementation group: 'org.json', name: 'json', version: '20230227' implementation group: 'com.amazonaws', name: 'aws-java-sdk-emr', version: '1.12.1' + implementation group: 'org.apache.spark', name: 'spark-catalyst_2.12', version: '3.4.0' + + implementation('com.fasterxml.jackson.core:jackson-annotations:2.15.2') {force = true} + implementation('com.fasterxml.jackson.module:jackson-module-scala_2.12:2.15.2') {force = true} + implementation('com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.15.2') {force = true} + implementation('org.apache.logging.log4j:log4j-api:2.20.0') {force = true} + implementation('org.apache.logging.log4j:log4j-core:2.20.0') {force = true} + implementation('com.google.protobuf:protobuf-java:3.22.3') {force = true} + implementation('com.google.code.findbugs:jsr305:3.0.2') {force = true} + implementation('org.objenesis:objenesis:3.2') {force = true} + implementation('org.antlr:antlr4-runtime:4.9.3') {force = true} + implementation('org.javassist:javassist:3.26.0-GA') {force = true} + implementation('com.github.luben:zstd-jni:1.5.5-5') {force = true} + implementation('org.scala-lang:scala-library:2.12.17') {force = true} + implementation('commons-io:commons-io:2.11.0') {force = true} + implementation('org.apache.zookeeper:zookeeper:3.6.3') {force = true} + implementation('org.apache.commons:commons-compress:1.22') {force = true} + implementation('org.xerial.snappy:snappy-java:1.1.9.1') {force = true} + implementation('io.netty:netty-transport-native-epoll:4.1.63.Final') {force = true} + implementation('io.netty:netty-handler:4.1.63.Final') {force = true} + implementation('io.netty:netty-buffer:4.1.63.Final') {force = true} + implementation('io.netty:netty-codec:4.1.63.Final') {force = true} + implementation('io.netty:netty-common:4.1.63.Final') {force = true} + implementation('io.netty:netty-transport-native-unix-common:4.1.63.Final') {force = true} + 
implementation('io.netty:netty-resolver:4.1.63.Final') {force = true} + implementation('io.netty:netty-transport:4.1.63.Final') {force = true} testImplementation('org.junit.jupiter:junit-jupiter:5.6.2') testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.2.0' diff --git a/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystPlanContext.java b/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystPlanContext.java new file mode 100644 index 0000000000..9d67e7ce43 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystPlanContext.java @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.ppl; + +import lombok.Getter; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.opensearch.sql.analysis.TypeEnvironment; +import org.opensearch.sql.expression.NamedExpression; +import org.opensearch.sql.expression.function.FunctionProperties; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +/** The context used for Catalyst logical plan. */ +public class CatalystPlanContext { + /** Environment stack for symbol scope management. */ + private TypeEnvironment environment; + + @Getter private LogicalPlan plan; + + @Getter private final List namedParseExpressions; + + @Getter private final FunctionProperties functionProperties; + + public CatalystPlanContext() { + this(new TypeEnvironment(null)); + } + + /** + * Class CTOR. + * + * @param environment Env to set to a new instance. + */ + public CatalystPlanContext(TypeEnvironment environment) { + this.environment = environment; + this.namedParseExpressions = new ArrayList<>(); + this.functionProperties = new FunctionProperties(); + } + + /** Push a new environment. */ + public void push() { + environment = new TypeEnvironment(environment); + } + + /** + * Return current environment. + * + * @return current environment + */ + public TypeEnvironment peek() { + return environment; + } + + /** + * update context with evolving plan + * @param plan + */ + public void plan(LogicalPlan plan) { + this.plan = plan; + } + /** + * Pop up current environment from environment chain. 
+ * + * @return current environment (before pop) + */ + public TypeEnvironment pop() { + Objects.requireNonNull(environment, "Fail to pop context due to no environment present"); + + TypeEnvironment curEnv = environment; + environment = curEnv.getParent(); + return curEnv; + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystQueryPlanVisitor.java b/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystQueryPlanVisitor.java new file mode 100644 index 0000000000..7c888aac21 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystQueryPlanVisitor.java @@ -0,0 +1,356 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.ppl; + +import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; +import org.apache.spark.sql.catalyst.analysis.UnresolvedTable; +import org.apache.spark.sql.catalyst.expressions.NamedExpression; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.AggregateFunction; +import org.opensearch.sql.ast.expression.Alias; +import org.opensearch.sql.ast.expression.AllFields; +import org.opensearch.sql.ast.expression.And; +import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.Compare; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.Function; +import org.opensearch.sql.ast.expression.Interval; +import org.opensearch.sql.ast.expression.Let; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.expression.Map; +import org.opensearch.sql.ast.expression.Not; +import org.opensearch.sql.ast.expression.Or; +import org.opensearch.sql.ast.expression.QualifiedName; +import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.expression.Xor; +import org.opensearch.sql.ast.statement.Explain; +import org.opensearch.sql.ast.statement.Query; +import org.opensearch.sql.ast.statement.Statement; +import org.opensearch.sql.ast.tree.Aggregation; +import org.opensearch.sql.ast.tree.Dedupe; +import org.opensearch.sql.ast.tree.Eval; +import org.opensearch.sql.ast.tree.Filter; +import org.opensearch.sql.ast.tree.Head; +import org.opensearch.sql.ast.tree.Project; +import org.opensearch.sql.ast.tree.RareTopN; +import org.opensearch.sql.ast.tree.Relation; +import org.opensearch.sql.ast.tree.Rename; +import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.ast.tree.TableFunction; +import org.opensearch.sql.common.utils.StringUtils; +import org.opensearch.sql.planner.logical.LogicalAggregation; +import org.opensearch.sql.planner.logical.LogicalDedupe; +import org.opensearch.sql.planner.logical.LogicalEval; +import org.opensearch.sql.planner.logical.LogicalProject; +import org.opensearch.sql.planner.logical.LogicalRareTopN; +import org.opensearch.sql.planner.logical.LogicalRemove; +import org.opensearch.sql.planner.logical.LogicalRename; +import org.opensearch.sql.planner.logical.LogicalSort; +import scala.Option; +import scala.collection.JavaConverters; +import scala.collection.Seq; + +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import static java.util.List.of; +import static scala.Option.empty; +import static 
scala.collection.JavaConverters.asScalaBuffer; + +/** + * Utility class to mask sensitive information in incoming PPL queries. + */ +public class CatalystQueryPlanVisitor extends AbstractNodeVisitor { + + private static final String MASK_LITERAL = "***"; + + private final ExpressionAnalyzer expressionAnalyzer; + + public CatalystQueryPlanVisitor() { + this.expressionAnalyzer = new ExpressionAnalyzer(); + } + + public String visit(Statement plan,CatalystPlanContext context) { + return plan.accept(this,context); + } + + /** + * Handle Query Statement. + */ + @Override + public String visitQuery(Query node, CatalystPlanContext context) { + return node.getPlan().accept(this, context); + } + + @Override + public String visitExplain(Explain node, CatalystPlanContext context) { + return node.getStatement().accept(this, context); + } + + @Override + public String visitRelation(Relation node, CatalystPlanContext context) { + QualifiedName qualifiedName = node.getTableQualifiedName(); + // todo - how to resolve the qualifiedName is its composed of a datasource + schema + // Create an UnresolvedTable node for a table named "qualifiedName" in the default namespace + String command = StringUtils.format("source=%s", node.getTableName()); + context.plan(new UnresolvedTable(asScalaBuffer(of(qualifiedName.toString())).toSeq(), command, empty())); + return command; + } + + @Override + public String visitTableFunction(TableFunction node, CatalystPlanContext context) { + String arguments = + node.getArguments().stream() + .map( + unresolvedExpression -> + this.expressionAnalyzer.analyze(unresolvedExpression, context)) + .collect(Collectors.joining(",")); + return StringUtils.format("source=%s(%s)", node.getFunctionName().toString(), arguments); + } + + @Override + public String visitFilter(Filter node, CatalystPlanContext context) { + String child = node.getChild().get(0).accept(this, context); + String condition = visitExpression(node.getCondition(),context); + return StringUtils.format("%s | where %s", child, condition); + } + + /** + * Build {@link LogicalRename}. + */ + @Override + public String visitRename(Rename node, CatalystPlanContext context) { + String child = node.getChild().get(0).accept(this, context); + ImmutableMap.Builder renameMapBuilder = new ImmutableMap.Builder<>(); + for (Map renameMap : node.getRenameList()) { + renameMapBuilder.put( + visitExpression(renameMap.getOrigin(),context), + ((Field) renameMap.getTarget()).getField().toString()); + } + String renames = + renameMapBuilder.build().entrySet().stream() + .map(entry -> StringUtils.format("%s as %s", entry.getKey(), entry.getValue())) + .collect(Collectors.joining(",")); + return StringUtils.format("%s | rename %s", child, renames); + } + + /** + * Build {@link LogicalAggregation}. + */ + @Override + public String visitAggregation(Aggregation node, CatalystPlanContext context) { + String child = node.getChild().get(0).accept(this, context); + final String group = visitExpressionList(node.getGroupExprList(),context); + return StringUtils.format( + "%s | stats %s", + child, String.join(" ", visitExpressionList(node.getAggExprList(),context), groupBy(group)).trim()); + } + + /** + * Build {@link LogicalRareTopN}. 
+ */ + @Override + public String visitRareTopN(RareTopN node, CatalystPlanContext context) { + final String child = node.getChild().get(0).accept(this, context); + List options = node.getNoOfResults(); + Integer noOfResults = (Integer) options.get(0).getValue().getValue(); + String fields = visitFieldList(node.getFields(),context); + String group = visitExpressionList(node.getGroupExprList(),context); + return StringUtils.format( + "%s | %s %d %s", + child, + node.getCommandType().name().toLowerCase(), + noOfResults, + String.join(" ", fields, groupBy(group)).trim()); + } + + /** + * Build {@link LogicalProject} or {@link LogicalRemove} from {@link Field}. + */ + @Override + public String visitProject(Project node, CatalystPlanContext context) { + String child = node.getChild().get(0).accept(this, context); + String arg = "+"; + String fields = visitExpressionList(node.getProjectList(),context); + + if (node.hasArgument()) { + Argument argument = node.getArgExprList().get(0); + Boolean exclude = (Boolean) argument.getValue().getValue(); + if (exclude) { + arg = "-"; + } + } + return StringUtils.format("%s | fields %s %s", child, arg, fields); + } + + /** + * Build {@link LogicalEval}. + */ + @Override + public String visitEval(Eval node, CatalystPlanContext context) { + String child = node.getChild().get(0).accept(this, context); + ImmutableList.Builder> expressionsBuilder = new ImmutableList.Builder<>(); + for (Let let : node.getExpressionList()) { + String expression = visitExpression(let.getExpression(),context); + String target = let.getVar().getField().toString(); + expressionsBuilder.add(ImmutablePair.of(target, expression)); + } + String expressions = + expressionsBuilder.build().stream() + .map(pair -> StringUtils.format("%s" + "=%s", pair.getLeft(), pair.getRight())) + .collect(Collectors.joining(" ")); + return StringUtils.format("%s | eval %s", child, expressions); + } + + /** + * Build {@link LogicalSort}. + */ + @Override + public String visitSort(Sort node, CatalystPlanContext context) { + String child = node.getChild().get(0).accept(this, context); + // the first options is {"count": "integer"} + String sortList = visitFieldList(node.getSortList(),context); + return StringUtils.format("%s | sort %s", child, sortList); + } + + /** + * Build {@link LogicalDedupe}. + */ + @Override + public String visitDedupe(Dedupe node, CatalystPlanContext context) { + String child = node.getChild().get(0).accept(this, context); + String fields = visitFieldList(node.getFields(),context); + List options = node.getOptions(); + Integer allowedDuplication = (Integer) options.get(0).getValue().getValue(); + Boolean keepEmpty = (Boolean) options.get(1).getValue().getValue(); + Boolean consecutive = (Boolean) options.get(2).getValue().getValue(); + + return StringUtils.format( + "%s | dedup %s %d keepempty=%b consecutive=%b", + child, fields, allowedDuplication, keepEmpty, consecutive); + } + + @Override + public String visitHead(Head node, CatalystPlanContext context) { + String child = node.getChild().get(0).accept(this, context); + Integer size = node.getSize(); + return StringUtils.format("%s | head %d", child, size); + } + + private String visitFieldList(List fieldList, CatalystPlanContext context) { + return fieldList.stream().map(field->visitExpression(field,context)).collect(Collectors.joining(",")); + } + + private String visitExpressionList(List expressionList,CatalystPlanContext context) { + return expressionList.isEmpty() + ? 
"" + : expressionList.stream().map(field->visitExpression(field,context)).collect(Collectors.joining(",")); + } + + private String visitExpression(UnresolvedExpression expression,CatalystPlanContext context) { + return expressionAnalyzer.analyze(expression, context); + } + + private String groupBy(String groupBy) { + return Strings.isNullOrEmpty(groupBy) ? "" : StringUtils.format("by %s", groupBy); + } + + /** + * Expression Analyzer. + */ + private static class ExpressionAnalyzer extends AbstractNodeVisitor { + + public String analyze(UnresolvedExpression unresolved, CatalystPlanContext context) { + return unresolved.accept(this, context); + } + + @Override + public String visitLiteral(Literal node, CatalystPlanContext context) { + return MASK_LITERAL; + } + + @Override + public String visitInterval(Interval node, CatalystPlanContext context) { + String value = node.getValue().accept(this, context); + String unit = node.getUnit().name(); + return StringUtils.format("INTERVAL %s %s", value, unit); + } + + @Override + public String visitAnd(And node, CatalystPlanContext context) { + String left = node.getLeft().accept(this, context); + String right = node.getRight().accept(this, context); + return StringUtils.format("%s and %s", left, right); + } + + @Override + public String visitOr(Or node, CatalystPlanContext context) { + String left = node.getLeft().accept(this, context); + String right = node.getRight().accept(this, context); + return StringUtils.format("%s or %s", left, right); + } + + @Override + public String visitXor(Xor node, CatalystPlanContext context) { + String left = node.getLeft().accept(this, context); + String right = node.getRight().accept(this, context); + return StringUtils.format("%s xor %s", left, right); + } + + @Override + public String visitNot(Not node, CatalystPlanContext context) { + String expr = node.getExpression().accept(this, context); + return StringUtils.format("not %s", expr); + } + + @Override + public String visitAggregateFunction(AggregateFunction node, CatalystPlanContext context) { + String arg = node.getField().accept(this, context); + return StringUtils.format("%s(%s)", node.getFuncName(), arg); + } + + @Override + public String visitFunction(Function node, CatalystPlanContext context) { + String arguments = + node.getFuncArgs().stream() + .map(unresolvedExpression -> analyze(unresolvedExpression, context)) + .collect(Collectors.joining(",")); + return StringUtils.format("%s(%s)", node.getFuncName(), arguments); + } + + @Override + public String visitCompare(Compare node, CatalystPlanContext context) { + String left = analyze(node.getLeft(), context); + String right = analyze(node.getRight(), context); + return StringUtils.format("%s %s %s", left, node.getOperator(), right); + } + + @Override + public String visitField(Field node, CatalystPlanContext context) { + return node.getField().toString(); + } + @Override + public String visitAllFields(AllFields node, CatalystPlanContext context) { + // Create an UnresolvedStar for all-fields projection + Seq projectList = JavaConverters.asScalaBuffer(Collections.singletonList((Object) UnresolvedStar$.MODULE$.apply(Option.>empty()))).toSeq(); + // Create a Project node with the UnresolvedStar + context.plan(new org.apache.spark.sql.catalyst.plans.logical.Project((Seq)projectList, context.getPlan())); + return "*"; + } + + @Override + public String visitAlias(Alias node, CatalystPlanContext context) { + String expr = node.getDelegated().accept(this, context); + return StringUtils.format("%s", expr); + } + } 
+} diff --git a/spark/src/main/java/org/opensearch/sql/spark/request/SparkQueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/request/SparkQueryRequest.java index 94c9795161..6d719f4cf4 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/request/SparkQueryRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/request/SparkQueryRequest.java @@ -13,4 +13,6 @@ public class SparkQueryRequest { /** SQL. */ private String sql; + /** PPL. */ + private String ppl; } diff --git a/spark/src/test/java/org/opensearch/sql/spark/client/EmrClientImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/client/EmrClientImplTest.java index 93dc0d6bc8..e21674d02c 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/client/EmrClientImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/client/EmrClientImplTest.java @@ -8,7 +8,7 @@ import static org.mockito.Mockito.any; import static org.mockito.Mockito.when; import static org.opensearch.sql.spark.constants.TestConstants.EMR_CLUSTER_ID; -import static org.opensearch.sql.spark.constants.TestConstants.QUERY; +import static org.opensearch.sql.spark.constants.TestConstants.SQL_QUERY; import static org.opensearch.sql.spark.utils.TestUtils.getJson; import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce; @@ -49,7 +49,7 @@ void testRunEmrApplication() { EmrClientImpl emrClientImpl = new EmrClientImpl(emr, EMR_CLUSTER_ID, flint, sparkResponse, null); - emrClientImpl.runEmrApplication(QUERY); + emrClientImpl.runEmrApplication(SQL_QUERY); } @Test @@ -70,7 +70,7 @@ void testRunEmrApplicationFailed() { new EmrClientImpl(emr, EMR_CLUSTER_ID, flint, sparkResponse, null); RuntimeException exception = Assertions.assertThrows( - RuntimeException.class, () -> emrClientImpl.runEmrApplication(QUERY)); + RuntimeException.class, () -> emrClientImpl.runEmrApplication(SQL_QUERY)); Assertions.assertEquals("Spark SQL application failed.", exception.getMessage()); } @@ -92,7 +92,7 @@ void testRunEmrApplicationCancelled() { new EmrClientImpl(emr, EMR_CLUSTER_ID, flint, sparkResponse, null); RuntimeException exception = Assertions.assertThrows( - RuntimeException.class, () -> emrClientImpl.runEmrApplication(QUERY)); + RuntimeException.class, () -> emrClientImpl.runEmrApplication(SQL_QUERY)); Assertions.assertEquals("Spark SQL application failed.", exception.getMessage()); } @@ -122,7 +122,7 @@ void testRunEmrApplicationRunnning() { EmrClientImpl emrClientImpl = new EmrClientImpl(emr, EMR_CLUSTER_ID, flint, sparkResponse, null); - emrClientImpl.runEmrApplication(QUERY); + emrClientImpl.runEmrApplication(SQL_QUERY); } @Test @@ -153,6 +153,6 @@ void testSql() { EmrClientImpl emrClientImpl = new EmrClientImpl(emr, EMR_CLUSTER_ID, flint, sparkResponse, null); - emrClientImpl.sql(QUERY); + emrClientImpl.sql(SQL_QUERY); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java b/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java index 2b1020568a..142d98b150 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java +++ b/spark/src/test/java/org/opensearch/sql/spark/constants/TestConstants.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.constants; public class TestConstants { - public static final String QUERY = "select 1"; + public static final String SQL_QUERY = "select 1"; + public static final String PPL_QUERY = "search source=accounts"; public static final String EMR_CLUSTER_ID = "j-123456789"; } diff --git 
a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionImplementationTest.java b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionImplementationTest.java index 120747e0d3..8e7d18882d 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionImplementationTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionImplementationTest.java @@ -8,7 +8,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.opensearch.sql.spark.constants.TestConstants.QUERY; +import static org.opensearch.sql.spark.constants.TestConstants.SQL_QUERY; import java.util.List; import org.junit.jupiter.api.Test; @@ -33,7 +33,7 @@ public class SparkSqlFunctionImplementationTest { void testValueOfAndTypeToString() { FunctionName functionName = new FunctionName("sql"); List namedArgumentExpressionList = - List.of(DSL.namedArgument("query", DSL.literal(QUERY))); + List.of(DSL.namedArgument("query", DSL.literal(SQL_QUERY))); SparkSqlFunctionImplementation sparkSqlFunctionImplementation = new SparkSqlFunctionImplementation(functionName, namedArgumentExpressionList, client); UnsupportedOperationException exception = @@ -51,13 +51,13 @@ void testValueOfAndTypeToString() { void testApplyArguments() { FunctionName functionName = new FunctionName("sql"); List namedArgumentExpressionList = - List.of(DSL.namedArgument("query", DSL.literal(QUERY))); + List.of(DSL.namedArgument("query", DSL.literal(SQL_QUERY))); SparkSqlFunctionImplementation sparkSqlFunctionImplementation = new SparkSqlFunctionImplementation(functionName, namedArgumentExpressionList, client); SparkTable sparkTable = (SparkTable) sparkSqlFunctionImplementation.applyArguments(); assertNotNull(sparkTable.getSparkQueryRequest()); SparkQueryRequest sparkQueryRequest = sparkTable.getSparkQueryRequest(); - assertEquals(QUERY, sparkQueryRequest.getSql()); + assertEquals(SQL_QUERY, sparkQueryRequest.getSql()); } @Test @@ -65,7 +65,7 @@ void testApplyArgumentsException() { FunctionName functionName = new FunctionName("sql"); List namedArgumentExpressionList = List.of( - DSL.namedArgument("query", DSL.literal(QUERY)), + DSL.namedArgument("query", DSL.literal(SQL_QUERY)), DSL.namedArgument("tmp", DSL.literal(12345))); SparkSqlFunctionImplementation sparkSqlFunctionImplementation = new SparkSqlFunctionImplementation(functionName, namedArgumentExpressionList, client); diff --git a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanBuilderTest.java b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanBuilderTest.java index 212056eb15..6bb0f2d1d7 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanBuilderTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanBuilderTest.java @@ -5,7 +5,7 @@ package org.opensearch.sql.spark.functions; -import static org.opensearch.sql.spark.constants.TestConstants.QUERY; +import static org.opensearch.sql.spark.constants.TestConstants.SQL_QUERY; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -25,7 +25,7 @@ public class SparkSqlFunctionTableScanBuilderTest { @Test void testBuild() { SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); - sparkQueryRequest.setSql(QUERY); + sparkQueryRequest.setSql(SQL_QUERY); 
SparkSqlFunctionTableScanBuilder sparkSqlFunctionTableScanBuilder = new SparkSqlFunctionTableScanBuilder(sparkClient, sparkQueryRequest); @@ -37,7 +37,7 @@ void testBuild() { @Test void testPushProject() { SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); - sparkQueryRequest.setSql(QUERY); + sparkQueryRequest.setSql(SQL_QUERY); SparkSqlFunctionTableScanBuilder sparkSqlFunctionTableScanBuilder = new SparkSqlFunctionTableScanBuilder(sparkClient, sparkQueryRequest); diff --git a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanOperatorTest.java b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanOperatorTest.java index 586f0ef2d8..0bf55e895c 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanOperatorTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanOperatorTest.java @@ -10,7 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; -import static org.opensearch.sql.spark.constants.TestConstants.QUERY; +import static org.opensearch.sql.spark.constants.TestConstants.SQL_QUERY; import static org.opensearch.sql.spark.utils.TestUtils.getJson; import java.io.IOException; @@ -49,7 +49,7 @@ public class SparkSqlFunctionTableScanOperatorTest { @SneakyThrows void testEmptyQueryWithException() { SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); - sparkQueryRequest.setSql(QUERY); + sparkQueryRequest.setSql(SQL_QUERY); SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator = new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); @@ -65,7 +65,7 @@ void testEmptyQueryWithException() { @SneakyThrows void testClose() { SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); - sparkQueryRequest.setSql(QUERY); + sparkQueryRequest.setSql(SQL_QUERY); SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator = new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); @@ -76,7 +76,7 @@ void testClose() { @SneakyThrows void testExplain() { SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); - sparkQueryRequest.setSql(QUERY); + sparkQueryRequest.setSql(SQL_QUERY); SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator = new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); @@ -88,7 +88,7 @@ void testExplain() { @SneakyThrows void testQueryResponseIterator() { SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); - sparkQueryRequest.setSql(QUERY); + sparkQueryRequest.setSql(SQL_QUERY); SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator = new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); @@ -111,7 +111,7 @@ void testQueryResponseIterator() { @SneakyThrows void testQueryResponseAllTypes() { SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); - sparkQueryRequest.setSql(QUERY); + sparkQueryRequest.setSql(SQL_QUERY); SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator = new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); @@ -144,7 +144,7 @@ void testQueryResponseAllTypes() { @SneakyThrows void testQueryResponseInvalidDataType() { SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); - sparkQueryRequest.setSql(QUERY); + sparkQueryRequest.setSql(SQL_QUERY); SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator = new 
SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); @@ -161,7 +161,7 @@ void testQueryResponseInvalidDataType() { @SneakyThrows void testQuerySchema() { SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); - sparkQueryRequest.setSql(QUERY); + sparkQueryRequest.setSql(SQL_QUERY); SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator = new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); diff --git a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlTableFunctionResolverTest.java b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlTableFunctionResolverTest.java index a828ac76c4..b8b400a843 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlTableFunctionResolverTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlTableFunctionResolverTest.java @@ -10,7 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.opensearch.sql.data.type.ExprCoreType.STRING; -import static org.opensearch.sql.spark.constants.TestConstants.QUERY; +import static org.opensearch.sql.spark.constants.TestConstants.SQL_QUERY; import java.util.List; import java.util.stream.Collectors; @@ -44,7 +44,7 @@ void testResolve() { SparkSqlTableFunctionResolver sqlTableFunctionResolver = new SparkSqlTableFunctionResolver(client); FunctionName functionName = FunctionName.of("sql"); - List expressions = List.of(DSL.namedArgument("query", DSL.literal(QUERY))); + List expressions = List.of(DSL.namedArgument("query", DSL.literal(SQL_QUERY))); FunctionSignature functionSignature = new FunctionSignature( functionName, expressions.stream().map(Expression::type).collect(Collectors.toList())); @@ -60,7 +60,7 @@ void testResolve() { SparkTable sparkTable = (SparkTable) functionImplementation.applyArguments(); assertNotNull(sparkTable.getSparkQueryRequest()); SparkQueryRequest sparkQueryRequest = sparkTable.getSparkQueryRequest(); - assertEquals(QUERY, sparkQueryRequest.getSql()); + assertEquals(SQL_QUERY, sparkQueryRequest.getSql()); } @Test @@ -68,7 +68,7 @@ void testArgumentsPassedByPosition() { SparkSqlTableFunctionResolver sqlTableFunctionResolver = new SparkSqlTableFunctionResolver(client); FunctionName functionName = FunctionName.of("sql"); - List expressions = List.of(DSL.namedArgument(null, DSL.literal(QUERY))); + List expressions = List.of(DSL.namedArgument(null, DSL.literal(SQL_QUERY))); FunctionSignature functionSignature = new FunctionSignature( functionName, expressions.stream().map(Expression::type).collect(Collectors.toList())); @@ -86,7 +86,7 @@ void testArgumentsPassedByPosition() { SparkTable sparkTable = (SparkTable) functionImplementation.applyArguments(); assertNotNull(sparkTable.getSparkQueryRequest()); SparkQueryRequest sparkQueryRequest = sparkTable.getSparkQueryRequest(); - assertEquals(QUERY, sparkQueryRequest.getSql()); + assertEquals(SQL_QUERY, sparkQueryRequest.getSql()); } @Test @@ -96,7 +96,7 @@ void testMixedArgumentTypes() { FunctionName functionName = FunctionName.of("sql"); List expressions = List.of( - DSL.namedArgument("query", DSL.literal(QUERY)), + DSL.namedArgument("query", DSL.literal(SQL_QUERY)), DSL.namedArgument(null, DSL.literal(12345))); FunctionSignature functionSignature = new FunctionSignature( diff --git a/spark/src/test/java/org/opensearch/sql/spark/ppl/SparkPPLLogicalBuilderTest.java b/spark/src/test/java/org/opensearch/sql/spark/ppl/SparkPPLLogicalBuilderTest.java new 
file mode 100644 index 0000000000..1987b94fcd --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/ppl/SparkPPLLogicalBuilderTest.java @@ -0,0 +1,119 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.ppl; + +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute; +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; +import org.apache.spark.sql.catalyst.analysis.UnresolvedFieldName$; +import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; +import org.apache.spark.sql.catalyst.analysis.UnresolvedTable; +import org.apache.spark.sql.catalyst.analysis.UnresolvedFieldName; +import org.apache.spark.sql.catalyst.expressions.NamedExpression; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.catalyst.plans.logical.Project; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.opensearch.sql.ast.statement.Statement; +import org.opensearch.sql.planner.logical.LogicalProject; +import org.opensearch.sql.ppl.antlr.PPLSyntaxParser; +import org.opensearch.sql.ppl.parser.AstBuilder; +import org.opensearch.sql.ppl.parser.AstExpressionBuilder; +import org.opensearch.sql.ppl.parser.AstStatementBuilder; +import org.opensearch.sql.spark.client.SparkClient; +import scala.Option; +import scala.collection.JavaConverters; +import scala.collection.Seq; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static java.util.List.of; +import static scala.Option.empty; +import static scala.collection.JavaConverters.asScalaBuffer; + + +public class SparkPPLLogicalBuilderTest { + private PPLSyntaxParser parser = new PPLSyntaxParser(); + @Mock + private SparkClient sparkClient; + + @Mock + private LogicalProject logicalProject; + private CatalystPlanContext context = new CatalystPlanContext(); + + private Statement plan(String query, boolean isExplain) { + final AstStatementBuilder builder = + new AstStatementBuilder( + new AstBuilder(new AstExpressionBuilder(), query), + AstStatementBuilder.StatementBuilderContext.builder().isExplain(isExplain).build()); + return builder.visit(parser.parse(query)); + } + + @Test + void testSearchWithTableAllFieldsPlan() { + Statement plan = plan("search source = table ", false); + CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); + planVisitor.visit(plan, context); + Seq projectList = JavaConverters.asScalaBuffer(Collections.singletonList((Object) UnresolvedStar$.MODULE$.apply(Option.>empty()))).toSeq(); + Project expectedPlan = new Project((Seq) projectList, new UnresolvedTable(asScalaBuffer(of("table")).toSeq(), "source=table ", Option.empty())); + Assertions.assertEquals(context.getPlan().toString(), expectedPlan.toString()); + } + @Test + void testSourceWithTableAllFieldsPlan() { + Statement plan = plan("source = table ", false); + CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); + planVisitor.visit(plan, context); + Seq projectList = JavaConverters.asScalaBuffer(Collections.singletonList((Object) UnresolvedStar$.MODULE$.apply(Option.>empty()))).toSeq(); + Project expectedPlan = new Project((Seq) projectList, new UnresolvedTable(asScalaBuffer(of("table")).toSeq(), "source=table ", Option.empty())); + Assertions.assertEquals(context.getPlan().toString(), expectedPlan.toString()); + } + + @Test + void testSourceWithTableOneFieldPlan() { + Statement plan = plan("source=table | fields 
A", false); + CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); + planVisitor.visit(plan, context); + // Create a Project node for fields A and B + List projectList = Arrays.asList( + UnresolvedAttribute$.MODULE$.apply(JavaConverters.asScalaBuffer(Collections.singletonList("A"))) + ); + Project expectedPlan = new Project(JavaConverters.asScalaBuffer(projectList).toSeq(), new UnresolvedTable(asScalaBuffer(of("table")).toSeq(), "source=table ", Option.empty())); + Assertions.assertEquals(context.getPlan().toString(), expectedPlan.toString()); + } + + @Test + void testSourceWithTableTwoFieldPlan() { + Statement plan = plan("source=table | fields A, B", false); + CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); + planVisitor.visit(plan, context); + // Create a Project node for fields A and B + List projectList = Arrays.asList( + UnresolvedAttribute$.MODULE$.apply(JavaConverters.asScalaBuffer(Collections.singletonList("A"))), + UnresolvedAttribute$.MODULE$.apply(JavaConverters.asScalaBuffer(Collections.singletonList("B"))) + ); + Project expectedPlan = new Project(JavaConverters.asScalaBuffer(projectList).toSeq(), new UnresolvedTable(asScalaBuffer(of("table")).toSeq(), "source=table ", Option.empty())); + Assertions.assertEquals(context.getPlan().toString(), expectedPlan.toString()); + } + + @Test + void testSearchWithMultiTablesPlan() { + Statement plan = plan("search source = table1, table2 ", false); + CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); + planVisitor.visit(plan,context); + } + + @Test + void testSearchWithWildcardBasedTableNamePlanException() { + Statement plan = plan("search source = table1, table2 ", false); + CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); + planVisitor.visit(plan,context); + } + +} + + diff --git a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkScanTest.java b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkScanTest.java index 971db3c33c..3b46a4a304 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkScanTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkScanTest.java @@ -7,7 +7,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; -import static org.opensearch.sql.spark.constants.TestConstants.QUERY; +import static org.opensearch.sql.spark.constants.TestConstants.SQL_QUERY; import lombok.SneakyThrows; import org.junit.jupiter.api.Assertions; @@ -25,7 +25,7 @@ public class SparkScanTest { @SneakyThrows void testQueryResponseIteratorForQueryRangeFunction() { SparkScan sparkScan = new SparkScan(sparkClient); - sparkScan.getRequest().setSql(QUERY); + sparkScan.getRequest().setSql(SQL_QUERY); Assertions.assertFalse(sparkScan.hasNext()); assertNull(sparkScan.next()); } @@ -34,7 +34,7 @@ void testQueryResponseIteratorForQueryRangeFunction() { @SneakyThrows void testExplain() { SparkScan sparkScan = new SparkScan(sparkClient); - sparkScan.getRequest().setSql(QUERY); + sparkScan.getRequest().setSql(SQL_QUERY); assertEquals("SparkQueryRequest(sql=select 1)", sparkScan.explain()); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkTableTest.java b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkTableTest.java index a70d4ba69e..386e544e6e 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkTableTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkTableTest.java @@ -10,7 +10,7 
@@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.opensearch.sql.spark.constants.TestConstants.QUERY; +import static org.opensearch.sql.spark.constants.TestConstants.SQL_QUERY; import java.util.Collections; import java.util.HashMap; @@ -46,7 +46,7 @@ void testUnsupportedOperation() { @Test void testCreateScanBuilderWithSqlTableFunction() { SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); - sparkQueryRequest.setSql(QUERY); + sparkQueryRequest.setSql(SQL_QUERY); SparkTable sparkTable = new SparkTable(client, sparkQueryRequest); TableScanBuilder tableScanBuilder = sparkTable.createScanBuilder(); Assertions.assertNotNull(tableScanBuilder); @@ -68,7 +68,7 @@ void testGetFieldTypesFromSparkQueryRequest() { @Test void testImplementWithSqlFunction() { SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); - sparkQueryRequest.setSql(QUERY); + sparkQueryRequest.setSql(SQL_QUERY); SparkTable sparkMetricTable = new SparkTable(client, sparkQueryRequest); PhysicalPlan plan = sparkMetricTable.implement(new SparkSqlFunctionTableScanBuilder(client, sparkQueryRequest)); From 9a8700e6c59b98942571b814c7ae6e26c0abbdcf Mon Sep 17 00:00:00 2001 From: YANGDB Date: Tue, 29 Aug 2023 19:05:23 -0700 Subject: [PATCH 2/7] update tests Signed-off-by: YANGDB --- .../sql/spark/ppl/CatalystPlanContext.java | 2 +- .../spark/ppl/CatalystQueryPlanVisitor.java | 14 ++++-- ...LToCatalystLogicalPlanTranslatorTest.java} | 49 +++++++++++++------ 3 files changed, 46 insertions(+), 19 deletions(-) rename spark/src/test/java/org/opensearch/sql/spark/ppl/{SparkPPLLogicalBuilderTest.java => PPLToCatalystLogicalPlanTranslatorTest.java} (69%) diff --git a/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystPlanContext.java b/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystPlanContext.java index 9d67e7ce43..e77cdb2072 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystPlanContext.java +++ b/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystPlanContext.java @@ -6,9 +6,9 @@ package org.opensearch.sql.spark.ppl; import lombok.Getter; +import org.apache.spark.sql.catalyst.expressions.NamedExpression; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; import org.opensearch.sql.analysis.TypeEnvironment; -import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.function.FunctionProperties; import java.util.ArrayList; diff --git a/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystQueryPlanVisitor.java b/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystQueryPlanVisitor.java index 7c888aac21..3e7b14ed30 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystQueryPlanVisitor.java +++ b/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystQueryPlanVisitor.java @@ -10,6 +10,7 @@ import com.google.common.collect.ImmutableMap; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; import org.apache.spark.sql.catalyst.analysis.UnresolvedTable; import org.apache.spark.sql.catalyst.expressions.NamedExpression; @@ -182,6 +183,11 @@ public String visitProject(Project node, CatalystPlanContext context) { String arg = "+"; String fields = visitExpressionList(node.getProjectList(),context); + // 
Create an UnresolvedStar for all-fields projection + Seq projectList = JavaConverters.asScalaBuffer(context.getNamedParseExpressions()).toSeq(); + // Create a Project node with the UnresolvedStar + context.plan(new org.apache.spark.sql.catalyst.plans.logical.Project((Seq)projectList, context.getPlan())); + if (node.hasArgument()) { Argument argument = node.getArgExprList().get(0); Boolean exclude = (Boolean) argument.getValue().getValue(); @@ -253,7 +259,8 @@ private String visitFieldList(List fieldList, CatalystPlanContext context private String visitExpressionList(List expressionList,CatalystPlanContext context) { return expressionList.isEmpty() ? "" - : expressionList.stream().map(field->visitExpression(field,context)).collect(Collectors.joining(",")); + : expressionList.stream().map(field->visitExpression(field,context)) + .collect(Collectors.joining(",")); } private String visitExpression(UnresolvedExpression expression,CatalystPlanContext context) { @@ -336,14 +343,13 @@ public String visitCompare(Compare node, CatalystPlanContext context) { @Override public String visitField(Field node, CatalystPlanContext context) { + context.getNamedParseExpressions().add(UnresolvedAttribute$.MODULE$.apply(JavaConverters.asScalaBuffer(Collections.singletonList(node.getField().toString())))); return node.getField().toString(); } @Override public String visitAllFields(AllFields node, CatalystPlanContext context) { // Create an UnresolvedStar for all-fields projection - Seq projectList = JavaConverters.asScalaBuffer(Collections.singletonList((Object) UnresolvedStar$.MODULE$.apply(Option.>empty()))).toSeq(); - // Create a Project node with the UnresolvedStar - context.plan(new org.apache.spark.sql.catalyst.plans.logical.Project((Seq)projectList, context.getPlan())); + context.getNamedParseExpressions().add(UnresolvedStar$.MODULE$.apply(Option.>empty())); return "*"; } diff --git a/spark/src/test/java/org/opensearch/sql/spark/ppl/SparkPPLLogicalBuilderTest.java b/spark/src/test/java/org/opensearch/sql/spark/ppl/PPLToCatalystLogicalPlanTranslatorTest.java similarity index 69% rename from spark/src/test/java/org/opensearch/sql/spark/ppl/SparkPPLLogicalBuilderTest.java rename to spark/src/test/java/org/opensearch/sql/spark/ppl/PPLToCatalystLogicalPlanTranslatorTest.java index 1987b94fcd..615c9baadb 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/ppl/SparkPPLLogicalBuilderTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/ppl/PPLToCatalystLogicalPlanTranslatorTest.java @@ -5,15 +5,17 @@ package org.opensearch.sql.spark.ppl; -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute; import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; -import org.apache.spark.sql.catalyst.analysis.UnresolvedFieldName$; import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; import org.apache.spark.sql.catalyst.analysis.UnresolvedTable; -import org.apache.spark.sql.catalyst.analysis.UnresolvedFieldName; +import org.apache.spark.sql.catalyst.expressions.EqualTo; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.Literal; import org.apache.spark.sql.catalyst.expressions.NamedExpression; +import org.apache.spark.sql.catalyst.plans.logical.Filter; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; import org.apache.spark.sql.catalyst.plans.logical.Project; +import org.apache.spark.sql.catalyst.plans.logical.Union; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import 
org.mockito.Mock; @@ -27,17 +29,18 @@ import scala.Option; import scala.collection.JavaConverters; import scala.collection.Seq; +import scala.reflect.internal.Trees; import java.util.Arrays; import java.util.Collections; import java.util.List; import static java.util.List.of; -import static scala.Option.empty; +import static org.apache.spark.sql.types.DataTypes.IntegerType; import static scala.collection.JavaConverters.asScalaBuffer; -public class SparkPPLLogicalBuilderTest { +public class PPLToCatalystLogicalPlanTranslatorTest { private PPLSyntaxParser parser = new PPLSyntaxParser(); @Mock private SparkClient sparkClient; @@ -53,7 +56,7 @@ private Statement plan(String query, boolean isExplain) { AstStatementBuilder.StatementBuilderContext.builder().isExplain(isExplain).build()); return builder.visit(parser.parse(query)); } - + @Test void testSearchWithTableAllFieldsPlan() { Statement plan = plan("search source = table ", false); @@ -63,6 +66,7 @@ void testSearchWithTableAllFieldsPlan() { Project expectedPlan = new Project((Seq) projectList, new UnresolvedTable(asScalaBuffer(of("table")).toSeq(), "source=table ", Option.empty())); Assertions.assertEquals(context.getPlan().toString(), expectedPlan.toString()); } + @Test void testSourceWithTableAllFieldsPlan() { Statement plan = plan("source = table ", false); @@ -86,6 +90,22 @@ void testSourceWithTableOneFieldPlan() { Assertions.assertEquals(context.getPlan().toString(), expectedPlan.toString()); } + @Test + void testSourceWithTableAndConditionPlan() { + Statement plan = plan("source=t a = 1 ", false); + CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); + planVisitor.visit(plan, context); + // Create a Project node for fields A and B + List projectList = Arrays.asList( + UnresolvedAttribute$.MODULE$.apply(JavaConverters.asScalaBuffer(Collections.singletonList("a"))) + ); + UnresolvedTable table = new UnresolvedTable(asScalaBuffer(of("table")).toSeq(), "source=table ", Option.empty()); + // Create a Filter node for the condition 'a = 1' + EqualTo filterCondition = new EqualTo((Expression)projectList.get(0), Literal.create(1,IntegerType)); + LogicalPlan filterPlan = new Filter(filterCondition, table); + Assertions.assertEquals(context.getPlan().toString(), filterPlan.toString()); + } + @Test void testSourceWithTableTwoFieldPlan() { Statement plan = plan("source=table | fields A, B", false); @@ -104,14 +124,15 @@ void testSourceWithTableTwoFieldPlan() { void testSearchWithMultiTablesPlan() { Statement plan = plan("search source = table1, table2 ", false); CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); - planVisitor.visit(plan,context); - } - - @Test - void testSearchWithWildcardBasedTableNamePlanException() { - Statement plan = plan("search source = table1, table2 ", false); - CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); - planVisitor.visit(plan,context); + planVisitor.visit(plan, context); + Seq projectList = JavaConverters.asScalaBuffer(Collections.singletonList((Object) UnresolvedStar$.MODULE$.apply(Option.>empty()))).toSeq(); + Project expectedPlanTable1 = new Project((Seq) projectList, new UnresolvedTable(asScalaBuffer(of("table1")).toSeq(), "source=table ", Option.empty())); + Project expectedPlanTable2 = new Project((Seq) projectList, new UnresolvedTable(asScalaBuffer(of("table2")).toSeq(), "source=table ", Option.empty())); + // Create a Union logical plan + Seq unionChildren = JavaConverters.asScalaBuffer(Arrays.asList((LogicalPlan) expectedPlanTable1, 
(LogicalPlan) expectedPlanTable2)).toSeq(); + // todo : parameterize the union options byName & allowMissingCol + LogicalPlan unionPlan = new Union(unionChildren, true, false); + Assertions.assertEquals(context.getPlan().toString(), unionPlan.toString()); } } From 839debc143aaee709bb7e328c2faab45875f1053 Mon Sep 17 00:00:00 2001 From: YANGDB Date: Wed, 30 Aug 2023 17:06:56 -0700 Subject: [PATCH 3/7] update with an SQL query builder for future support in the PPL into SQL translation path Adding multiple PPL queries for different test purpose Signed-off-by: YANGDB --- .../org/opensearch/sql/utils/Builder.java | 740 ++++++++++++++++++ .../opensearch/sql/utils/QueryBuilder.java | 13 + .../sql/utils/SQLQueryBuilderTest.java | 287 +++++++ .../sql/spark/ppl/CatalystPlanContext.java | 4 +- .../spark/ppl/CatalystQueryPlanVisitor.java | 6 +- ...PLToCatalystLogicalPlanTranslatorTest.java | 40 +- 6 files changed, 1084 insertions(+), 6 deletions(-) create mode 100644 core/src/main/java/org/opensearch/sql/utils/Builder.java create mode 100644 core/src/main/java/org/opensearch/sql/utils/QueryBuilder.java create mode 100644 core/src/test/java/org/opensearch/sql/utils/SQLQueryBuilderTest.java diff --git a/core/src/main/java/org/opensearch/sql/utils/Builder.java b/core/src/main/java/org/opensearch/sql/utils/Builder.java new file mode 100644 index 0000000000..c24bfb7c97 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/utils/Builder.java @@ -0,0 +1,740 @@ +package org.opensearch.sql.utils; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; + +/** + * The SQL (textual) query wrapper and fluent query builder + */ +public class Builder implements QueryBuilder { + private List parameters = new LinkedList<>(); + private String name; + private StringBuilder sqlBuilder; + + public static Builder instance() { + return new Builder(); + } + + @Override + public String build() { + return this.getSQL(); + } + + @Override + public QueryBuilder withName(String name) { + this.name = name; + return this; + } + + public Builder() { + this(new StringBuilder()); + } + + /** + * Constructs a query builder from an existing SQL query. + * + * @param sql The existing SQL query. + */ + public Builder(String sql) { + if (sql == null) { + throw new IllegalArgumentException(); + } + + sqlBuilder = new StringBuilder(); + + append(sql); + } + + private Builder(StringBuilder sqlBuilder) { + this.sqlBuilder = sqlBuilder; + } + + /** + * Creates a "select" query. + * + * @param columns The column names. + * @return The new {@link Builder} instance. + */ + public static Builder select(String... columns) { + if (columns == null || columns.length == 0) { + throw new IllegalArgumentException(); + } + + var sqlBuilder = new StringBuilder(); + + sqlBuilder.append("select "); + + var SQLQueryBuilder = new Builder(sqlBuilder); + + for (var i = 0; i < columns.length; i++) { + if (i > 0) { + SQLQueryBuilder.sqlBuilder.append(", "); + } + + SQLQueryBuilder.append(columns[i]); + } + + return SQLQueryBuilder; + } + + /** + * Appends a "from" clause to a query. + * + * @param tables The table names. + * @return The {@link Builder} instance. + */ + public Builder from(String... tables) { + if (tables == null || tables.length == 0) { + throw new IllegalArgumentException(); + } + + sqlBuilder.append(" from "); + sqlBuilder.append(String.join(", ", tables)); + + return this; + } + + /** + * Appends a "from" clause to a query. 
+ * + * @param Builder A "select" subquery. + * @param alias The subquery's alias. + * @return The {@link Builder} instance. + */ + public Builder from(Builder Builder, String alias) { + if (Builder == null || alias == null) { + throw new IllegalArgumentException(); + } + + sqlBuilder.append(" from ("); + sqlBuilder.append(Builder.getSQL()); + sqlBuilder.append(") "); + sqlBuilder.append(alias); + + parameters.addAll(Builder.parameters); + + return this; + } + + /** + * Appends a "join" clause to a query. + * + * @param table The table name. + * @return The {@link Builder} instance. + */ + public Builder join(String table) { + if (table == null) { + throw new IllegalArgumentException(); + } + + sqlBuilder.append(" join "); + sqlBuilder.append(table); + + return this; + } + + /** + * Appends a "join" clause to a query. + * + * @param Builder A "select" subquery. + * @param alias The subquery's alias. + * @return The {@link Builder} instance. + */ + public Builder join(Builder Builder, String alias) { + if (Builder == null || alias == null) { + throw new IllegalArgumentException(); + } + + sqlBuilder.append(" join ("); + sqlBuilder.append(Builder.getSQL()); + sqlBuilder.append(") "); + sqlBuilder.append(alias); + + parameters.addAll(Builder.parameters); + + return this; + } + + /** + * Appends a "left join" clause to a query. + * + * @param table The table name. + * @return The {@link Builder} instance. + */ + public Builder leftJoin(String table) { + if (table == null) { + throw new IllegalArgumentException(); + } + + sqlBuilder.append(" left join "); + sqlBuilder.append(table); + + return this; + } + + /** + * Appends a "right join" clause to a query. + * + * @param table The table name. + * @return The {@link Builder} instance. + */ + public Builder rightJoin(String table) { + if (table == null) { + throw new IllegalArgumentException(); + } + + sqlBuilder.append(" right join "); + sqlBuilder.append(table); + + return this; + } + + /** + * Appends an "on" clause to a query. + * + * @param predicates The clause predicates. + * @return The {@link Builder} instance. + */ + public Builder on(String... predicates) { + return filter("on", predicates); + } + + /** + * Appends a "where" clause to a query. + * + * @param predicates The clause predicates. + * @return The {@link Builder} instance. + */ + public Builder where(String... predicates) { + return filter("where", predicates); + } + + private Builder filter(String clause, String... predicates) { + if (predicates == null) { + throw new IllegalArgumentException(); + } + + sqlBuilder.append(" "); + sqlBuilder.append(clause); + sqlBuilder.append(" "); + + for (var i = 0; i < predicates.length; i++) { + if (i > 0) { + sqlBuilder.append(" "); + } + + append(predicates[i]); + } + + return this; + } + + /** + * Creates an "and" conditional. + * + * @param predicates The conditional's predicates. + * @return The conditional text. + */ + public static String and(String... predicates) { + return conditional("and", predicates); + } + + /** + * Creates an "or" conditional. + * + * @param predicates The conditional's predicates. + * @return The conditional text. + */ + public static String or(String... predicates) { + return conditional("or", predicates); + } + + private static String conditional(String operator, String... 
predicates) { + if (predicates == null || predicates.length == 0) { + throw new IllegalArgumentException(); + } + + var stringBuilder = new StringBuilder(); + + stringBuilder.append(operator); + stringBuilder.append(" "); + + if (predicates.length > 1) { + stringBuilder.append("("); + } + + for (var i = 0; i < predicates.length; i++) { + if (i > 0) { + stringBuilder.append(" "); + } + + stringBuilder.append(predicates[i]); + } + + if (predicates.length > 1) { + stringBuilder.append(")"); + } + + return stringBuilder.toString(); + } + + /** + * Creates an "and" conditional group. + * + * @param predicates The group's predicates. + * @return The conditional text. + */ + public static String allOf(String... predicates) { + if (predicates == null || predicates.length == 0) { + throw new IllegalArgumentException(); + } + + return conditionalGroup("and", predicates); + } + + /** + * Creates an "or" conditional group. + * + * @param predicates The group's predicates. + * @return The conditional text. + */ + public static String anyOf(String... predicates) { + if (predicates == null || predicates.length == 0) { + throw new IllegalArgumentException(); + } + + return conditionalGroup("or", predicates); + } + + private static String conditionalGroup(String operator, String... predicates) { + if (predicates == null || predicates.length == 0) { + throw new IllegalArgumentException(); + } + + var stringBuilder = new StringBuilder(); + + stringBuilder.append("("); + + for (var i = 0; i < predicates.length; i++) { + if (i > 0) { + stringBuilder.append(" "); + stringBuilder.append(operator); + stringBuilder.append(" "); + } + + stringBuilder.append(predicates[i]); + } + + stringBuilder.append(")"); + + return stringBuilder.toString(); + } + + /** + * Creates an "equal to" conditional. + * + * @param Builder The conditional's subquery. + * @return The conditional text. + */ + public static String equalTo(Builder Builder) { + if (Builder == null) { + throw new IllegalArgumentException(); + } + + return String.format("= (%s)", Builder); + } + + /** + * Creates a "not equal to" conditional. + * + * @param Builder The conditional's subquery. + * @return The conditional text. + */ + public static String notEqualTo(Builder Builder) { + if (Builder == null) { + throw new IllegalArgumentException(); + } + + return String.format("!= (%s)", Builder); + } + + /** + * Creates an "in" conditional. + * + * @param Builder The conditional's subquery. + * @return The conditional text. + */ + public static String in(Builder Builder) { + if (Builder == null) { + throw new IllegalArgumentException(); + } + + return String.format("in (%s)", Builder); + } + + /** + * Creates a "not in" conditional. + * + * @param Builder The conditional's subquery. + * @return The conditional text. + */ + public static String notIn(Builder Builder) { + if (Builder == null) { + throw new IllegalArgumentException(); + } + + return String.format("not in (%s)", Builder); + } + + /** + * Creates an "exists" conditional. + * + * @param Builder The conditional's subquery. + * @return The conditional text. + */ + public static String exists(Builder Builder) { + if (Builder == null) { + throw new IllegalArgumentException(); + } + + return String.format("exists (%s)", Builder); + } + + /** + * Creates a "not exists" conditional. + * + * @param Builder The conditional's subquery. + * @return The conditional text. 
+ */ + public static String notExists(Builder Builder) { + if (Builder == null) { + throw new IllegalArgumentException(); + } + + return String.format("not exists (%s)", Builder); + } + + /** + * Appends an "order by" clause to a query. + * + * @param columns The column names. + * @return The {@link Builder} instance. + */ + public Builder orderBy(String... columns) { + if (columns == null || columns.length == 0) { + throw new IllegalArgumentException(); + } + + sqlBuilder.append(" order by "); + sqlBuilder.append(String.join(", ", columns)); + + return this; + } + + /** + * Appends a "limit" clause to a query. + * + * @param count The limit count. + * @return The {@link Builder} instance. + */ + public Builder limit(int count) { + if (count < 0) { + throw new IllegalArgumentException(); + } + + sqlBuilder.append(" limit "); + sqlBuilder.append(count); + + return this; + } + + /** + * Appends a "for update" clause to a query. + * + * @return The {@link Builder} instance. + */ + public Builder forUpdate() { + sqlBuilder.append(" for update"); + + return this; + } + + /** + * Appends a "union" clause to a query. + * + * @param Builder The query builder to append. + * @return The {@link Builder} instance. + */ + public Builder union(Builder Builder) { + if (Builder == null) { + throw new IllegalArgumentException(); + } + + sqlBuilder.append(" union "); + sqlBuilder.append(Builder.getSQL()); + + parameters.addAll(Builder.parameters); + + return this; + } + + /** + * Creates an "insert into" query. + * + * @param table The table name. + * @return The new {@link Builder} instance. + */ + public static Builder insertInto(String table) { + if (table == null) { + throw new IllegalArgumentException(); + } + + var sqlBuilder = new StringBuilder(); + + sqlBuilder.append("insert into "); + sqlBuilder.append(table); + + return new Builder(sqlBuilder); + } + + /** + * Appends column values to an "insert into" query. + * + * @param values The values to insert. + * @return The {@link Builder} instance. + */ + public Builder values(Map values) { + if (values == null) { + throw new IllegalArgumentException(); + } + + sqlBuilder.append(" ("); + + List columns = new ArrayList<>(values.keySet()); + + var n = columns.size(); + + for (var i = 0; i < n; i++) { + if (i > 0) { + sqlBuilder.append(", "); + } + + sqlBuilder.append(columns.get(i)); + } + + sqlBuilder.append(") values ("); + + for (var i = 0; i < n; i++) { + if (i > 0) { + sqlBuilder.append(", "); + } + + encode(values.get(columns.get(i))); + } + + sqlBuilder.append(")"); + + return this; + } + + /** + * Creates an "update" query. + * + * @param table The table name. + * @return The new {@link Builder} instance. + */ + public static Builder update(String table) { + if (table == null) { + throw new IllegalArgumentException(); + } + + var sqlBuilder = new StringBuilder(); + + sqlBuilder.append("update "); + sqlBuilder.append(table); + + return new Builder(sqlBuilder); + } + + /** + * Appends column values to an "update" query. + * + * @param values The values to update. + * @return The {@link Builder} instance. + */ + public Builder set(Map values) { + if (values == null) { + throw new IllegalArgumentException(); + } + + sqlBuilder.append(" set "); + + var i = 0; + + for (Map.Entry entry : values.entrySet()) { + if (i > 0) { + sqlBuilder.append(", "); + } + + sqlBuilder.append(entry.getKey()); + sqlBuilder.append(" = "); + + encode(entry.getValue()); + + i++; + } + + return this; + } + + /** + * Creates a "delete from" query. 
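+ * No "where" clause is appended here; chain {@link #where} onto the returned builder to constrain the deletion.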
+ * + * @param table The table name. + * @return The new {@link Builder} instance. + */ + public static Builder deleteFrom(String table) { + if (table == null) { + throw new IllegalArgumentException(); + } + + var sqlBuilder = new StringBuilder(); + + sqlBuilder.append("delete from "); + sqlBuilder.append(table); + + return new Builder(sqlBuilder); + } + + + /** + * Returns the parameters parsed by the query builder. + * + * @return The parameters parsed by the query builder. + */ + public Collection getParameters() { + return Collections.unmodifiableList(parameters); + } + + /** + * Returns the generated SQL. + * + * @return The generated SQL. + */ + public String getSQL() { + return sqlBuilder.toString(); + } + + private void append(String sql) { + var quoted = false; + + var n = sql.length(); + var i = 0; + + while (i < n) { + var c = sql.charAt(i++); + + if (c == ':' && !quoted) { + var parameterBuilder = new StringBuilder(); + + while (i < n) { + c = sql.charAt(i); + + if (!Character.isJavaIdentifierPart(c)) { + break; + } + + parameterBuilder.append(c); + + i++; + } + + if (parameterBuilder.length() == 0) { + throw new IllegalArgumentException("Missing parameter name."); + } + + parameters.add(parameterBuilder.toString()); + + sqlBuilder.append("?"); + } else if (c == '?' && !quoted) { + parameters.add(null); + + sqlBuilder.append(c); + } else { + if (c == '\'') { + quoted = !quoted; + } + + sqlBuilder.append(c); + } + } + } + + private void encode(Object value) { + if (value instanceof String) { + var string = (String) value; + + if (string.startsWith(":") || string.equals("?")) { + append(string); + } else { + sqlBuilder.append("'"); + + for (int i = 0, n = string.length(); i < n; i++) { + var c = string.charAt(i); + + if (c == '\'') { + sqlBuilder.append(c); + } + + sqlBuilder.append(c); + } + + sqlBuilder.append("'"); + } + } else if (value instanceof Builder) { + var SQLQueryBuilder = (Builder) value; + + sqlBuilder.append("("); + sqlBuilder.append(SQLQueryBuilder.getSQL()); + sqlBuilder.append(")"); + + parameters.addAll(SQLQueryBuilder.parameters); + } else { + sqlBuilder.append(value); + } + } + + /** + * Returns the query as a string. 
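+ * Positional '?' markers that originated from named parameters are rendered back in their ':name' form, which makes the result suitable for logging.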
+ * {@inheritDoc} + */ + @Override + public String toString() { + var stringBuilder = new StringBuilder(); + + var parameterIterator = parameters.iterator(); + + for (int i = 0, n = sqlBuilder.length(); i < n; i++) { + var c = sqlBuilder.charAt(i); + + if (c == '?') { + var parameter = parameterIterator.next(); + + if (parameter == null) { + stringBuilder.append(c); + } else { + stringBuilder.append(':'); + stringBuilder.append(parameter); + } + } else { + stringBuilder.append(c); + } + } + + return stringBuilder.toString(); + } +} diff --git a/core/src/main/java/org/opensearch/sql/utils/QueryBuilder.java b/core/src/main/java/org/opensearch/sql/utils/QueryBuilder.java new file mode 100644 index 0000000000..873d601415 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/utils/QueryBuilder.java @@ -0,0 +1,13 @@ +package org.opensearch.sql.utils; + +/** + * this Query builder interface defines the contract needed to be implemented by any query language builder factory + * @param + */ +public interface QueryBuilder { + + T build(); + + QueryBuilder withName(String name); + +} \ No newline at end of file diff --git a/core/src/test/java/org/opensearch/sql/utils/SQLQueryBuilderTest.java b/core/src/test/java/org/opensearch/sql/utils/SQLQueryBuilderTest.java new file mode 100644 index 0000000000..cb44bf9af6 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/utils/SQLQueryBuilderTest.java @@ -0,0 +1,287 @@ +package org.opensearch.sql.utils; +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +import org.junit.jupiter.api.Test; + +import java.util.AbstractMap; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.opensearch.sql.utils.Builder.allOf; +import static org.opensearch.sql.utils.Builder.and; +import static org.opensearch.sql.utils.Builder.anyOf; +import static org.opensearch.sql.utils.Builder.equalTo; +import static org.opensearch.sql.utils.Builder.exists; +import static org.opensearch.sql.utils.Builder.in; +import static org.opensearch.sql.utils.Builder.notEqualTo; +import static org.opensearch.sql.utils.Builder.notExists; +import static org.opensearch.sql.utils.Builder.notIn; +import static org.opensearch.sql.utils.Builder.or; + +public class SQLQueryBuilderTest { + @Test + public void testSelect() { + var queryBuilder = Builder.select(":a as a", "b", "c", "d") + .from("A") + .join("B").on("A.id = B.id", and("x = 50")) + .leftJoin("C").on("B.id = C.id", and("b = :b")) + .rightJoin("D").on("C.id = D.id", and("c = :c")) + .where("a > 10", or("b < 200", and("d != ?"))) + .orderBy("a", "b") + .limit(10) + .forUpdate() + .union(Builder.select("a", "b", "c", "d").from("C").where("c = :c")); + + assertEquals(Arrays.asList("a", "b", "c", null, "c"), queryBuilder.getParameters()); + + assertEquals("select ? 
as a, b, c, d from A " + + "join B on A.id = B.id and x = 50 " + + "left join C on B.id = C.id and b = ? " + + "right join D on C.id = D.id and c = ? " + + "where a > 10 or (b < 200 and d != ?) " + + "order by a, b " + + "limit 10 " + + "for update " + + "union select a, b, c, d from C where c = ?", queryBuilder.getSQL()); + } + + @Test + public void testSelectWithSubquery() { + var queryBuilder = Builder.select("*") + .from(Builder.select("a, b, c").from("A"), "a") + .where("d = :d"); + + assertEquals(Arrays.asList("d"), queryBuilder.getParameters()); + + assertEquals("select * from (select a, b, c from A) a where d = ?", queryBuilder.getSQL()); + } + + @Test + public void testSelectWithSubqueryJoin() { + var queryBuilder = Builder.select("*").from("A") + .join(Builder.select("c, d").from("C"), "c").on("c = :c") + .where("d = :d"); + + assertEquals(Arrays.asList("c", "d"), queryBuilder.getParameters()); + + assertEquals("select * from A join (select c, d from C) c on c = ? where d = ?", queryBuilder.getSQL()); + } + + @Test + public void testInsertInto() { + var queryBuilder = Builder.insertInto("A").values(mapOf( + entry("a", 1), + entry("b", true), + entry("c", "hello"), + entry("d", ":d"), + entry("e", "?"), + entry("f", Builder.select("f").from("F").where("g = :g")) + )); + assertEquals(Arrays.asList("d", null, "g"),queryBuilder.getParameters()); + assertEquals("insert into A (a, b, c, d, e, f) values (1, true, 'hello', ?, ?, (select f from F where g = ?))", queryBuilder.getSQL()); + } + + @Test + public void testUpdate() { + var queryBuilder = Builder.update("A").set(mapOf( + entry("a", 1), + entry("b", true), + entry("c", "hello"), + entry("d", ":d"), + entry("e", "?"), + entry("f", Builder.select("f").from("F").where("g = :g"))) + ).where("a is not null"); + + assertEquals(Arrays.asList("d", null, "g"), queryBuilder.getParameters()); + + assertEquals("update A set a = 1, b = true, c = 'hello', d = ?, e = ?, f = (select f from F where g = ?) where a is not null", queryBuilder.getSQL()); + } + + @Test + public void testUpdateWithExpression() { + var queryBuilder = Builder.update("xyz").set(mapOf( + entry("foo", ":a + b") + )).where("c = :d"); + + assertEquals(Arrays.asList("a", "d"), queryBuilder.getParameters()); + + assertEquals("update xyz set foo = ? 
+ b where c = ?", queryBuilder.getSQL()); + } + + @Test + public void testDelete() { + var queryBuilder = Builder.deleteFrom("A").where("a < 150"); + + assertEquals("delete from A where a < 150", queryBuilder.getSQL()); + } + + @Test + public void testConditionalGroups() { + var queryBuilder = Builder.select("*").from("xyz").where(allOf("a = 1", "b = 2", "c = 3"), and(anyOf("d = 4", "e = 5"))); + + assertEquals("select * from xyz where (a = 1 and b = 2 and c = 3) and (d = 4 or e = 5)", queryBuilder.getSQL()); + } + + @Test + public void testEqualToConditional() { + var queryBuilder = Builder.select("*") + .from("A") + .where("b", equalTo( + Builder.select("b").from("B").where("c = :c") + )); + + assertEquals("select * from A where b = (select b from B where c = ?)", queryBuilder.getSQL()); + } + + @Test + public void testNotEqualToConditional() { + var queryBuilder = Builder.select("*") + .from("A") + .where("b", notEqualTo( + Builder.select("b").from("B").where("c = :c") + )); + + assertEquals("select * from A where b != (select b from B where c = ?)", queryBuilder.getSQL()); + } + + @Test + public void testInConditional() { + var queryBuilder = Builder.select("*") + .from("B") + .where("c", in( + Builder.select("c").from("C").where("d = :d") + )); + + assertEquals("select * from B where c in (select c from C where d = ?)", queryBuilder.getSQL()); + } + + @Test + public void testNotInConditional() { + var queryBuilder = Builder.select("*").from("D").where("e", notIn( + Builder.select("e").from("E") + )); + + assertEquals("select * from D where e not in (select e from E)", queryBuilder.getSQL()); + } + + @Test + public void testExistsConditional() { + var queryBuilder = Builder.select("*") + .from("B") + .where(exists( + Builder.select("c").from("C").where("d = :d") + )); + + assertEquals("select * from B where exists (select c from C where d = ?)", queryBuilder.getSQL()); + } + + @Test + public void testNotExistsConditional() { + var queryBuilder = Builder.select("*").from("D").where("e", notExists( + Builder.select("e").from("E") + )); + + assertEquals("select * from D where e not exists (select e from E)", queryBuilder.getSQL()); + } + + @Test + public void testQuotedColon() { + var queryBuilder = Builder.select("*").from("xyz").where("foo = 'a:b:c'"); + + assertEquals("select * from xyz where foo = 'a:b:c'", queryBuilder.getSQL()); + } + + @Test + public void testQuotedQuestionMark() { + var queryBuilder = Builder.select("'?' as q").from("xyz"); + + assertEquals("select '?' 
as q from xyz", queryBuilder.getSQL()); + } + + @Test + public void testDoubleColon() { + assertThrows(IllegalArgumentException.class, () -> Builder.select("'ab:c'::varchar(16) as abc")); + } + + @Test + public void testEscapedQuotes() { + var queryBuilder = Builder.select("xyz.*", "''':z' as z").from("xyz").where("foo = 'a''b'':c'''", and("bar = ''''")); + + assertEquals("select xyz.*, ''':z' as z from xyz where foo = 'a''b'':c''' and bar = ''''", queryBuilder.getSQL()); + } + + @Test + public void testMissingPredicateParameterName() { + assertThrows(IllegalArgumentException.class, () -> Builder.select("*").from("xyz").where("foo = :")); + } + + @Test + public void testMissingValueParameterName() { + assertThrows(IllegalArgumentException.class, () -> Builder.insertInto("xyz").values(mapOf(entry("foo", ":")) + )); + } + + @Test + public void testExistingSQL() { + var queryBuilder = new Builder("select a, 'b''c:d' as b from foo where bar = :x"); + + assertEquals(Arrays.asList("x"), queryBuilder.getParameters()); + + assertEquals("select a, 'b''c:d' as b from foo where bar = ?", queryBuilder.getSQL()); + } + + @Test + public void testToString() { + var queryBuilder = Builder.select("*").from("xyz").where("foo = :a", and("bar = :b", or("bar = :c"))); + + assertEquals("select * from xyz where foo = :a and (bar = :b or bar = :c)", queryBuilder.toString()); + } + + + @SafeVarargs + public static Map mapOf(Map.Entry... entries) { + if (entries == null) { + throw new IllegalArgumentException(); + } + + Map map = new LinkedHashMap<>(); + + for (var entry : entries) { + map.put(entry.getKey(), entry.getValue()); + } + + return java.util.Collections.unmodifiableMap(map); + } + + /** + * Creates an immutable map entry. + * + * @param The key type. + * @param The value type. + * @param key The entry key. + * @param value The entry value. + * @return An immutable map entry containing the provided key/value pair. + */ + public static Map.Entry entry(K key, V value) { + return new AbstractMap.SimpleImmutableEntry<>(key, value); + } + + +} + diff --git a/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystPlanContext.java b/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystPlanContext.java index e77cdb2072..1dbb110cd0 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystPlanContext.java +++ b/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystPlanContext.java @@ -20,8 +20,10 @@ public class CatalystPlanContext { /** Environment stack for symbol scope management. */ private TypeEnvironment environment; + /** Catalyst evolving logical plan **/ @Getter private LogicalPlan plan; - + + /** NamedExpression contextual parameters **/ @Getter private final List namedParseExpressions; @Getter private final FunctionProperties functionProperties; diff --git a/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystQueryPlanVisitor.java b/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystQueryPlanVisitor.java index 3e7b14ed30..776b1ab7ee 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystQueryPlanVisitor.java +++ b/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystQueryPlanVisitor.java @@ -68,12 +68,10 @@ import static scala.collection.JavaConverters.asScalaBuffer; /** - * Utility class to mask sensitive information in incoming PPL queries. 
+ * Utility class to traverse PPL logical plan and translate it into catalyst logical plan */ public class CatalystQueryPlanVisitor extends AbstractNodeVisitor { - private static final String MASK_LITERAL = "***"; - private final ExpressionAnalyzer expressionAnalyzer; public CatalystQueryPlanVisitor() { @@ -282,7 +280,7 @@ public String analyze(UnresolvedExpression unresolved, CatalystPlanContext conte @Override public String visitLiteral(Literal node, CatalystPlanContext context) { - return MASK_LITERAL; + return node.toString(); } @Override diff --git a/spark/src/test/java/org/opensearch/sql/spark/ppl/PPLToCatalystLogicalPlanTranslatorTest.java b/spark/src/test/java/org/opensearch/sql/spark/ppl/PPLToCatalystLogicalPlanTranslatorTest.java index 615c9baadb..02cc70f8a5 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/ppl/PPLToCatalystLogicalPlanTranslatorTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/ppl/PPLToCatalystLogicalPlanTranslatorTest.java @@ -101,12 +101,15 @@ void testSourceWithTableAndConditionPlan() { ); UnresolvedTable table = new UnresolvedTable(asScalaBuffer(of("table")).toSeq(), "source=table ", Option.empty()); // Create a Filter node for the condition 'a = 1' - EqualTo filterCondition = new EqualTo((Expression)projectList.get(0), Literal.create(1,IntegerType)); + EqualTo filterCondition = new EqualTo((Expression) projectList.get(0), Literal.create(1, IntegerType)); LogicalPlan filterPlan = new Filter(filterCondition, table); Assertions.assertEquals(context.getPlan().toString(), filterPlan.toString()); } @Test + /** + * + */ void testSourceWithTableTwoFieldPlan() { Statement plan = plan("source=table | fields A, B", false); CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); @@ -121,6 +124,9 @@ void testSourceWithTableTwoFieldPlan() { } @Test + /** + * Search multiple tables - translated into union call + */ void testSearchWithMultiTablesPlan() { Statement plan = plan("search source = table1, table2 ", false); CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); @@ -135,6 +141,38 @@ void testSearchWithMultiTablesPlan() { Assertions.assertEquals(context.getPlan().toString(), unionPlan.toString()); } + @Test + /** + * Find the top 10 most expensive properties in California, including their addresses, prices, and cities + */ + void testFindTopTenExpensivePropertiesCalifornia() { + Statement plan = plan("source = housing_properties | where state = \"CA\" | fields address, price, city | sort - price | head 10", false); + } + + @Test + /** + * Find the average price per unit of land space for properties in different cities + */ + void testFindAvgPricePerUnitByCity() { + Statement plan = plan("source = housing_properties | where land_space > 0 | eval price_per_land_unit = price / land_space | stats avg(price_per_land_unit) by city", false); + } + + @Test + /** + * Find the houses posted in the last month, how many are still for sale + */ + void testFindHousesForSaleDuringLastMonth() { + Statement plan = plan("search source=housing_properties | where listing_age >= 0 | where listing_age < 30 | stats count() by property_status", false); + } + + @Test + /** + * Find all the houses listed by agency Compass in decreasing price order. Also provide only price, address and agency name information. 
+ */ + void testFindHousesByDecreasePriceWithSpecificFields() { + Statement plan = plan("source = housing_properties | where match( agency_name , \"Compass\" ) | fields address , agency_name , price | sort - price ", false); + } + } From d7337b62b2605852e953e25a387302e4a0af8dbe Mon Sep 17 00:00:00 2001 From: YANGDB Date: Wed, 30 Aug 2023 17:23:24 -0700 Subject: [PATCH 4/7] Add multiple diverse PPL queries for different test purposes and use cases Signed-off-by: YANGDB --- ...PLToCatalystLogicalPlanTranslatorTest.java | 79 +++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/spark/src/test/java/org/opensearch/sql/spark/ppl/PPLToCatalystLogicalPlanTranslatorTest.java b/spark/src/test/java/org/opensearch/sql/spark/ppl/PPLToCatalystLogicalPlanTranslatorTest.java index 02cc70f8a5..1b0db4621b 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/ppl/PPLToCatalystLogicalPlanTranslatorTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/ppl/PPLToCatalystLogicalPlanTranslatorTest.java @@ -142,6 +142,14 @@ void testSearchWithMultiTablesPlan() { } @Test + /** + * Find the average prices for different types of properties + */ + void testFindAvgPricesForVariousPropertyTypes() { + Statement plan = plan("source = housing_properties | stats avg(price) by property_type", false); + } + + @Test /** * Find the top 10 most expensive properties in California, including their addresses, prices, and cities */ void testFindTopTenExpensivePropertiesCalifornia() { @@ -171,6 +179,77 @@ void testFindHousesForSaleDuringLastMonth() { */ void testFindHousesByDecreasePriceWithSpecificFields() { Statement plan = plan("source = housing_properties | where match( agency_name , \"Compass\" ) | fields address , agency_name , price | sort - price ", false); + } + @Test + /** + * Find details of properties owned by Zillow with at least 3 bedrooms and 2 bathrooms + */ + void testFindHousesByOwnedByZillowWithMinimumOfRoomsWithSpecificFields() { + Statement plan = plan("source = housing_properties | where is_owned_by_zillow = 1 and bedroom_number >= 3 and bathroom_number >= 2 | fields address, price, city, listing_age", false); + } + @Test + /** + * Find which cities in WA state have the largest number of houses for sale + */ + void testFindCitiesInWAHavingLargeNumbrOfHouseForSale() { + Statement plan = plan("source = housing_properties | where property_status = 'FOR_SALE' and state = 'WA' | stats count() as count by city | sort -count | head", false); + } + @Test + /** + * Find the top 5 referrers for the "/" path in the Apache access logs + */ + void testFindTopFiveReferrers() { + Statement plan = plan("source = access_logs | where path = \"/\" | top 5 referer", false); + } + @Test + /** + * Find access paths by status code. How many error responses (status code 400 or higher) are there for each access path in the Apache access logs? + */ + void testFindCountAccessLogsByStatusCode() { + Statement plan = plan("source = access_logs | where status >= 400 | stats count() by path, status", false); + }
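+ // Illustrative sketch of the kind of expected-plan assertion these tests still need (see the
+ // todos added in the next patch); for the status-code query above, assuming GreaterThanOrEqual
+ // from org.apache.spark.sql.catalyst.expressions and the same Catalyst helpers used in the
+ // earlier tests:
+ // UnresolvedTable table = new UnresolvedTable(asScalaBuffer(of("access_logs")).toSeq(), "source=access_logs", Option.empty());
+ // Expression status = UnresolvedAttribute$.MODULE$.apply(JavaConverters.asScalaBuffer(Collections.singletonList("status")));
+ // LogicalPlan expected = new Filter(new GreaterThanOrEqual(status, Literal.create(400, IntegerType)), table);
+ // Assertions.assertEquals(context.getPlan().toString(), expected.toString());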
+ @Test + /** + * Find the max size of nginx access requests in every 15-minute window + */ + void testFindMaxSizeOfNginxRequestsWithWindowTimeSpan() { + Statement plan = plan("source = access_logs | stats max(size) by span( request_time , 15m) ", false); + } + @Test + /** + * Find nginx logs with a non-2xx status code and a URL containing 'products' + */ + void testFindNginxLogsWithNon2XXStatusAndProductURL() { + Statement plan = plan("source = sso_logs-nginx-* | where match(http.url, 'products') and http.response.status_code >= \"300\"", false); + } + @Test + /** + * Find the details (URL, response status code, timestamp, source address) of events in the nginx logs where the response status code is 400 or higher + */ + void testFindDetailsOfNginxLogsWithResponseAbove400() { + Statement plan = plan("source = sso_logs-nginx-* | where http.response.status_code >= \"400\" | fields http.url, http.response.status_code, @timestamp, communication.source.address", false); + } + @Test + /** + * Find the average and max HTTP response sizes, grouped by request method, for access events in the nginx logs + */ + void testFindAvgAndMaxHttpResponseSizeGroupedBy() { + Statement plan = plan("source = sso_logs-nginx-* | where event.name = \"access\" | stats avg(http.response.bytes), max(http.response.bytes) by http.request.method", false); + } + @Test + /** + * Find which carrier has the longest average delay for flights over 6,000 miles + */ + void testFindFlightsWithCarrierHasLongestAvgDelayWithLongFlights() { + Statement plan = plan("source = opensearch_dashboards_sample_data_flights | where DistanceMiles > 6000 | stats avg(FlightDelayMin) by Carrier | sort -`avg(FlightDelayMin)` | head 1", false); + } + + @Test + /** + * Find the average RAM usage of Windows machines over time, aggregated by one-week spans + */ + void testFindAvgRamUsageOfWindowsMachineOverTime() { + Statement plan = plan("source = opensearch_dashboards_sample_data_logs | where match(machine.os, 'win') | stats avg(machine.ram) by span(timestamp,1w)", false); } } From 5a4e59d6ec4c976fa1f3732fc6f8254c40a9b223 Mon Sep 17 00:00:00 2001 From: YANGDB Date: Wed, 30 Aug 2023 17:28:20 -0700 Subject: [PATCH 5/7] Add multiple diverse PPL queries for different test purposes and use cases Signed-off-by: YANGDB --- ...PLToCatalystLogicalPlanTranslatorTest.java | 74 ++++++++++++++++++- 1 file changed, 73 insertions(+), 1 deletion(-) diff --git a/spark/src/test/java/org/opensearch/sql/spark/ppl/PPLToCatalystLogicalPlanTranslatorTest.java b/spark/src/test/java/org/opensearch/sql/spark/ppl/PPLToCatalystLogicalPlanTranslatorTest.java index 1b0db4621b..65da64336d 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/ppl/PPLToCatalystLogicalPlanTranslatorTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/ppl/PPLToCatalystLogicalPlanTranslatorTest.java @@ -58,6 +58,9 @@ private Statement plan(String query, boolean isExplain) { } @Test + /** + * test simple search with only one table and no explicit fields (defaults to all fields) + */ void testSearchWithTableAllFieldsPlan() { Statement plan = plan("search source = table ", false); CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); @@ -68,6 +71,9 @@ + /** + * test simple search with only one table and no explicit fields (defaults to all fields) + */ void testSourceWithTableAllFieldsPlan() { Statement plan = plan("source = table ", false); CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); @@ -78,6 +84,9 @@ + /** + 
* test simple search with only one table with one field projected + */ void testSourceWithTableOneFieldPlan() { Statement plan = plan("source=table | fields A", false); CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); @@ -91,6 +100,9 @@ void testSourceWithTableOneFieldPlan() { } @Test + /** + * test simple search with only one table with one field literal filtered + */ void testSourceWithTableAndConditionPlan() { Statement plan = plan("source=t a = 1 ", false); CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); @@ -108,7 +120,7 @@ void testSourceWithTableAndConditionPlan() { @Test /** - * + * test simple search with only one table with two fields projected */ void testSourceWithTableTwoFieldPlan() { Statement plan = plan("source=table | fields A, B", false); @@ -147,6 +159,10 @@ void testSearchWithMultiTablesPlan() { */ void testFindAvgPricesForVariousPropertyTypes() { Statement plan = plan("source = housing_properties | stats avg(price) by property_type", false); + CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); + planVisitor.visit(plan, context); + //todo add expected catalyst logical plan & compare + Assertions.assertEquals(false,true); } @Test @@ -155,6 +171,10 @@ void testFindAvgPricesForVariousPropertyTypes() { */ void testFindTopTenExpensivePropertiesCalifornia() { Statement plan = plan("source = housing_properties | where state = \"CA\" | fields address, price, city | sort - price | head 10", false); + CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); + planVisitor.visit(plan, context); + //todo add expected catalyst logical plan & compare + Assertions.assertEquals(false,true); } @Test @@ -163,6 +183,10 @@ void testFindTopTenExpensivePropertiesCalifornia() { */ void testFindAvgPricePerUnitByCity() { Statement plan = plan("source = housing_properties | where land_space > 0 | eval price_per_land_unit = price / land_space | stats avg(price_per_land_unit) by city", false); + CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); + planVisitor.visit(plan, context); + //todo add expected catalyst logical plan & compare + Assertions.assertEquals(false,true); } @Test @@ -171,6 +195,10 @@ void testFindAvgPricePerUnitByCity() { */ void testFindHousesForSaleDuringLastMonth() { Statement plan = plan("search source=housing_properties | where listing_age >= 0 | where listing_age < 30 | stats count() by property_status", false); + CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); + planVisitor.visit(plan, context); + //todo add expected catalyst logical plan & compare + Assertions.assertEquals(false,true); } @Test @@ -179,6 +207,10 @@ void testFindHousesForSaleDuringLastMonth() { */ void testFindHousesByDecreasePriceWithSpecificFields() { Statement plan = plan("source = housing_properties | where match( agency_name , \"Compass\" ) | fields address , agency_name , price | sort - price ", false); + CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); + planVisitor.visit(plan, context); + //todo add expected catalyst logical plan & compare + Assertions.assertEquals(false,true); } @Test /** @@ -186,6 +218,10 @@ void testFindHousesByDecreasePriceWithSpecificFields() { */ void testFindHousesByOwnedByZillowWithMinimumOfRoomsWithSpecificFields() { Statement plan = plan("source = housing_properties | where is_owned_by_zillow = 1 and bedroom_number >= 3 and bathroom_number >= 2 | fields address, price, city, listing_age", false); + CatalystQueryPlanVisitor planVisitor 
= new CatalystQueryPlanVisitor(); + planVisitor.visit(plan, context); + //todo add expected catalyst logical plan & compare + Assertions.assertEquals(false,true); } @Test /** @@ -193,6 +229,10 @@ void testFindHousesByOwnedByZillowWithMinimumOfRoomsWithSpecificFields() { */ void testFindCitiesInWAHavingLargeNumbrOfHouseForSale() { Statement plan = plan("source = housing_properties | where property_status = 'FOR_SALE' and state = 'WA' | stats count() as count by city | sort -count | head", false); + CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); + planVisitor.visit(plan, context); + //todo add expected catalyst logical plan & compare + Assertions.assertEquals(false,true); } @Test /** @@ -200,6 +240,10 @@ void testFindCitiesInWAHavingLargeNumbrOfHouseForSale() { */ void testFindTopFiveReferrers() { Statement plan = plan("source = access_logs | where path = \"/\" | top 5 referer", false); + CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); + planVisitor.visit(plan, context); + //todo add expected catalyst logical plan & compare + Assertions.assertEquals(false,true); } @Test /** @@ -207,6 +251,10 @@ void testFindTopFiveReferrers() { */ void testFindCountAccessLogsByStatusCode() { Statement plan = plan("source = access_logs | where status >= 400 | stats count() by path, status", false); + CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); + planVisitor.visit(plan, context); + //todo add expected catalyst logical plan & compare + Assertions.assertEquals(false,true); } @Test /** @@ -214,6 +262,10 @@ void testFindCountAccessLogsByStatusCode() { */ void testFindMaxSizeOfNginxRequestsWithWindowTimeSpan() { Statement plan = plan("source = access_logs | stats max(size) by span( request_time , 15m) ", false); + CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); + planVisitor.visit(plan, context); + //todo add expected catalyst logical plan & compare + Assertions.assertEquals(false,true); } @Test /** @@ -221,6 +273,10 @@ void testFindMaxSizeOfNginxRequestsWithWindowTimeSpan() { */ void testFindNginxLogsWithNon2XXStatusAndProductURL() { Statement plan = plan("source = sso_logs-nginx-* | where match(http.url, 'products') and http.response.status_code >= \"300\"", false); + CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); + planVisitor.visit(plan, context); + //todo add expected catalyst logical plan & compare + Assertions.assertEquals(false,true); } @Test /** @@ -228,6 +284,10 @@ void testFindNginxLogsWithNon2XXStatusAndProductURL() { */ void testFindDetailsOfNginxLogsWithResponseAbove400() { Statement plan = plan("source = sso_logs-nginx-* | where http.response.status_code >= \"400\" | fields http.url, http.response.status_code, @timestamp, communication.source.address", false); + CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); + planVisitor.visit(plan, context); + //todo add expected catalyst logical plan & compare + Assertions.assertEquals(false,true); } @Test /** @@ -235,6 +295,10 @@ void testFindDetailsOfNginxLogsWithResponseAbove400() { */ void testFindAvgAndMaxHttpResponseSizeGroupedBy() { Statement plan = plan("source = sso_logs-nginx-* | where event.name = \"access\" | stats avg(http.response.bytes), max(http.response.bytes) by http.request.method", false); + CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); + planVisitor.visit(plan, context); + //todo add expected catalyst logical plan & compare + Assertions.assertEquals(false,true); } @Test /** @@ 
-242,6 +306,10 @@ void testFindAvgAndMaxHttpResponseSizeGroupedBy() { */ void testFindFlightsWithCarrierHasLongestAvgDelayWithLongFlights() { Statement plan = plan("source = opensearch_dashboards_sample_data_flights | where DistanceMiles > 6000 | stats avg(FlightDelayMin) by Carrier | sort -`avg(FlightDelayMin)` | head 1", false); + CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); + planVisitor.visit(plan, context); + //todo add expected catalyst logical plan & compare + Assertions.assertEquals(false,true); } @Test @@ -250,6 +318,10 @@ void testFindFlightsWithCarrierHasLongestAvgDelayWithLongFlights() { */ void testFindAvgRamUsageOfWindowsMachineOverTime() { Statement plan = plan("source = opensearch_dashboards_sample_data_logs | where match(machine.os, 'win') | stats avg(machine.ram) by span(timestamp,1w)", false); + CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); + planVisitor.visit(plan, context); + //todo add expected catalyst logical plan & compare + Assertions.assertEquals(false,true); } } From 58febc3b571ff35c190d9959d9d21a17c030c400 Mon Sep 17 00:00:00 2001 From: YANGDB Date: Mon, 4 Sep 2023 12:55:12 -0700 Subject: [PATCH 6/7] add ComparatorTransformer skeleton Signed-off-by: YANGDB --- .../spark/ppl/CatalystQueryPlanVisitor.java | 1 + .../sql/spark/ppl/ComparatorTransformer.java | 432 ++++++++++++++++++ ...PLToCatalystLogicalPlanTranslatorTest.java | 39 +- 3 files changed, 469 insertions(+), 3 deletions(-) create mode 100644 spark/src/main/java/org/opensearch/sql/spark/ppl/ComparatorTransformer.java diff --git a/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystQueryPlanVisitor.java b/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystQueryPlanVisitor.java index 776b1ab7ee..1fa64e40d6 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystQueryPlanVisitor.java +++ b/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystQueryPlanVisitor.java @@ -334,6 +334,7 @@ public String visitFunction(Function node, CatalystPlanContext context) { @Override public String visitCompare(Compare node, CatalystPlanContext context) { + String left = analyze(node.getLeft(), context); String right = analyze(node.getRight(), context); return StringUtils.format("%s %s %s", left, node.getOperator(), right); diff --git a/spark/src/main/java/org/opensearch/sql/spark/ppl/ComparatorTransformer.java b/spark/src/main/java/org/opensearch/sql/spark/ppl/ComparatorTransformer.java new file mode 100644 index 0000000000..5d158b1fc8 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/ppl/ComparatorTransformer.java @@ -0,0 +1,432 @@ +package org.opensearch.sql.spark.ppl; + +import org.apache.spark.sql.catalyst.expressions.BinaryComparison; +import org.opensearch.sql.ast.expression.Compare; +import org.opensearch.sql.expression.function.BuiltinFunctionName; + +/** + * Transform the PPL Logical comparator into catalyst comparator + */ +public class ComparatorTransformer { + /** + * expression builder + * @return + */ + public static BinaryComparison comparator(Compare expression) { + if(BuiltinFunctionName.of(expression.getOperator()).isEmpty()) + throw new IllegalStateException("Unexpected value: " + BuiltinFunctionName.of(expression.getOperator())); + + switch (BuiltinFunctionName.of(expression.getOperator()).get()) { + case ABS: + break; + case CEIL: + break; + case CEILING: + break; + case CONV: + break; + case CRC32: + break; + case E: + break; + case EXP: + break; + case EXPM1: + break; + case FLOOR: + break; + case LN: + break; 
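+ // Of the built-ins enumerated here, only the comparison operators further down
+ // (EQUAL, NOTEQUAL, LESS, LTE, GREATER, GTE) can map onto a Catalyst BinaryComparison;
+ // the arithmetic, date/time and string functions are listed as placeholders until
+ // dedicated transformers exist for them.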
+ case LOG: + break; + case LOG10: + break; + case LOG2: + break; + case PI: + break; + case POW: + break; + case POWER: + break; + case RAND: + break; + case RINT: + break; + case ROUND: + break; + case SIGN: + break; + case SIGNUM: + break; + case SINH: + break; + case SQRT: + break; + case CBRT: + break; + case TRUNCATE: + break; + case ACOS: + break; + case ASIN: + break; + case ATAN: + break; + case ATAN2: + break; + case COS: + break; + case COSH: + break; + case COT: + break; + case DEGREES: + break; + case RADIANS: + break; + case SIN: + break; + case TAN: + break; + case ADDDATE: + break; + case ADDTIME: + break; + case CONVERT_TZ: + break; + case DATE: + break; + case DATEDIFF: + break; + case DATETIME: + break; + case DATE_ADD: + break; + case DATE_FORMAT: + break; + case DATE_SUB: + break; + case DAY: + break; + case DAYNAME: + break; + case DAYOFMONTH: + break; + case DAY_OF_MONTH: + break; + case DAYOFWEEK: + break; + case DAYOFYEAR: + break; + case DAY_OF_WEEK: + break; + case DAY_OF_YEAR: + break; + case EXTRACT: + break; + case FROM_DAYS: + break; + case FROM_UNIXTIME: + break; + case GET_FORMAT: + break; + case HOUR: + break; + case HOUR_OF_DAY: + break; + case LAST_DAY: + break; + case MAKEDATE: + break; + case MAKETIME: + break; + case MICROSECOND: + break; + case MINUTE: + break; + case MINUTE_OF_DAY: + break; + case MINUTE_OF_HOUR: + break; + case MONTH: + break; + case MONTH_OF_YEAR: + break; + case MONTHNAME: + break; + case PERIOD_ADD: + break; + case PERIOD_DIFF: + break; + case QUARTER: + break; + case SEC_TO_TIME: + break; + case SECOND: + break; + case SECOND_OF_MINUTE: + break; + case STR_TO_DATE: + break; + case SUBDATE: + break; + case SUBTIME: + break; + case TIME: + break; + case TIMEDIFF: + break; + case TIME_TO_SEC: + break; + case TIMESTAMP: + break; + case TIMESTAMPADD: + break; + case TIMESTAMPDIFF: + break; + case TIME_FORMAT: + break; + case TO_DAYS: + break; + case TO_SECONDS: + break; + case UTC_DATE: + break; + case UTC_TIME: + break; + case UTC_TIMESTAMP: + break; + case UNIX_TIMESTAMP: + break; + case WEEK: + break; + case WEEKDAY: + break; + case WEEKOFYEAR: + break; + case WEEK_OF_YEAR: + break; + case YEAR: + break; + case YEARWEEK: + break; + case NOW: + break; + case CURDATE: + break; + case CURRENT_DATE: + break; + case CURTIME: + break; + case CURRENT_TIME: + break; + case LOCALTIME: + break; + case CURRENT_TIMESTAMP: + break; + case LOCALTIMESTAMP: + break; + case SYSDATE: + break; + case TOSTRING: + break; + case ADD: + break; + case ADDFUNCTION: + break; + case DIVIDE: + break; + case DIVIDEFUNCTION: + break; + case MOD: + break; + case MODULUS: + break; + case MODULUSFUNCTION: + break; + case MULTIPLY: + break; + case MULTIPLYFUNCTION: + break; + case SUBTRACT: + break; + case SUBTRACTFUNCTION: + break; + case AND: + break; + case OR: + break; + case XOR: + break; + case NOT: + break; + case EQUAL: +// return new EqualTo() + break; + case NOTEQUAL: + break; + case LESS: + break; + case LTE: + break; + case GREATER: + break; + case GTE: + break; + case LIKE: + break; + case NOT_LIKE: + break; + case AVG: + break; + case SUM: + break; + case COUNT: + break; + case MIN: + break; + case MAX: + break; + case VARSAMP: + break; + case VARPOP: + break; + case STDDEV_SAMP: + break; + case STDDEV_POP: + break; + case TAKE: + break; + case NESTED: + break; + case ASCII: + break; + case CONCAT: + break; + case CONCAT_WS: + break; + case LEFT: + break; + case LENGTH: + break; + case LOCATE: + break; + case LOWER: + break; + case LTRIM: + break; + 
case POSITION: + break; + case REGEXP: + break; + case REPLACE: + break; + case REVERSE: + break; + case RIGHT: + break; + case RTRIM: + break; + case STRCMP: + break; + case SUBSTR: + break; + case SUBSTRING: + break; + case TRIM: + break; + case UPPER: + break; + case IS_NULL: + break; + case IS_NOT_NULL: + break; + case IFNULL: + break; + case IF: + break; + case NULLIF: + break; + case ISNULL: + break; + case ROW_NUMBER: + break; + case RANK: + break; + case DENSE_RANK: + break; + case INTERVAL: + break; + case CAST_TO_STRING: + break; + case CAST_TO_BYTE: + break; + case CAST_TO_SHORT: + break; + case CAST_TO_INT: + break; + case CAST_TO_LONG: + break; + case CAST_TO_FLOAT: + break; + case CAST_TO_DOUBLE: + break; + case CAST_TO_BOOLEAN: + break; + case CAST_TO_DATE: + break; + case CAST_TO_TIME: + break; + case CAST_TO_TIMESTAMP: + break; + case CAST_TO_DATETIME: + break; + case TYPEOF: + break; + case MATCH: + break; + case SIMPLE_QUERY_STRING: + break; + case MATCH_PHRASE: + break; + case MATCHPHRASE: + break; + case MATCHPHRASEQUERY: + break; + case QUERY_STRING: + break; + case MATCH_BOOL_PREFIX: + break; + case HIGHLIGHT: + break; + case MATCH_PHRASE_PREFIX: + break; + case SCORE: + break; + case SCOREQUERY: + break; + case SCORE_QUERY: + break; + case QUERY: + break; + case MATCH_QUERY: + break; + case MATCHQUERY: + break; + case MULTI_MATCH: + break; + case MULTIMATCH: + break; + case MULTIMATCHQUERY: + break; + case WILDCARDQUERY: + break; + case WILDCARD_QUERY: + break; + default: + return null; + } + return null; + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/ppl/PPLToCatalystLogicalPlanTranslatorTest.java b/spark/src/test/java/org/opensearch/sql/spark/ppl/PPLToCatalystLogicalPlanTranslatorTest.java index 65da64336d..13445cebcc 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/ppl/PPLToCatalystLogicalPlanTranslatorTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/ppl/PPLToCatalystLogicalPlanTranslatorTest.java @@ -29,7 +29,6 @@ import scala.Option; import scala.collection.JavaConverters; import scala.collection.Seq; -import scala.reflect.internal.Trees; import java.util.Arrays; import java.util.Collections; @@ -41,9 +40,13 @@ public class PPLToCatalystLogicalPlanTranslatorTest { + private PPLSyntaxParser parser = new PPLSyntaxParser(); @Mock private SparkClient sparkClient; +// private Catalog catalog = mock(Catalog.class); +// private SQLConf sqlConf = mock(SQLConf.class); + @Mock private LogicalProject logicalProject; @@ -57,6 +60,14 @@ private Statement plan(String query, boolean isExplain) { return builder.visit(parser.parse(query)); } + /* + private LogicalPlan analyze(LogicalPlan logicalPlan) { + // Analyze plan + Analyzer analyzer = new Analyzer(spark.sessionState().catalog(), spark.sessionState().conf()); + return analyzer.execute(logicalPlan); + } + */ + @Test /** * test simple search with only one table and no explicit fields (defaults to all fields) @@ -107,6 +118,28 @@ void testSourceWithTableAndConditionPlan() { Statement plan = plan("source=t a = 1 ", false); CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); planVisitor.visit(plan, context); + + Seq projectList = JavaConverters.asScalaBuffer(Collections.singletonList((Object) UnresolvedStar$.MODULE$.apply(Option.>empty()))).toSeq(); + // Create a Project node for fields A and B + List filterField = Arrays.asList( + UnresolvedAttribute$.MODULE$.apply(JavaConverters.asScalaBuffer(Collections.singletonList("a"))) + ); + UnresolvedTable table = new 
UnresolvedTable(asScalaBuffer(of("table")).toSeq(), "source=table ", Option.empty()); + // Create a Filter node for the condition 'a = 1' + EqualTo filterCondition = new EqualTo((Expression) filterField.get(0), Literal.create(1, IntegerType)); + LogicalPlan filterPlan = new Filter(filterCondition, table); + Project project = new Project((Seq) projectList, filterPlan); + Assertions.assertEquals(context.getPlan().toString(), project.toString()); + } + + @Test + /** + * test simple search with only one table with one field literal filtered and one field projected + */ + void testSourceWithTableAndConditionWithOneFieldPlan() { + Statement plan = plan("source=t a = 1 | fields a", false); + CatalystQueryPlanVisitor planVisitor = new CatalystQueryPlanVisitor(); + planVisitor.visit(plan, context); // Create a Project node for fields A and B List projectList = Arrays.asList( UnresolvedAttribute$.MODULE$.apply(JavaConverters.asScalaBuffer(Collections.singletonList("a"))) @@ -115,7 +148,8 @@ void testSourceWithTableAndConditionPlan() { // Create a Filter node for the condition 'a = 1' EqualTo filterCondition = new EqualTo((Expression) projectList.get(0), Literal.create(1, IntegerType)); LogicalPlan filterPlan = new Filter(filterCondition, table); - Assertions.assertEquals(context.getPlan().toString(), filterPlan.toString()); + Project project = new Project(asScalaBuffer(projectList).toSeq(), filterPlan); + Assertions.assertEquals(context.getPlan().toString(), project.toString()); } @Test @@ -323,7 +357,6 @@ void testFindAvgRamUsageOfWindowsMachineOverTime() { //todo add expected catalyst logical plan & compare Assertions.assertEquals(false,true); } - } From acaf59cb8bbb9c38810e009e28f74e3fb64e6af9 Mon Sep 17 00:00:00 2001 From: YANGDB Date: Tue, 5 Sep 2023 15:58:31 -0700 Subject: [PATCH 7/7] support basic filter plans with basic literals add ComparatorTransformer include initial data-type transformations Signed-off-by: YANGDB --- .../sql/spark/ppl/CatalystPlanContext.java | 6 +++-- .../spark/ppl/CatalystQueryPlanVisitor.java | 8 +++++- .../sql/spark/ppl/ComparatorTransformer.java | 22 +++++++++++----- .../sql/spark/ppl/DataTypeTransformer.java | 26 +++++++++++++++++++ ...PLToCatalystLogicalPlanTranslatorTest.java | 4 +-- 5 files changed, 54 insertions(+), 12 deletions(-) create mode 100644 spark/src/main/java/org/opensearch/sql/spark/ppl/DataTypeTransformer.java diff --git a/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystPlanContext.java b/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystPlanContext.java index 1dbb110cd0..c436780822 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystPlanContext.java +++ b/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystPlanContext.java @@ -6,6 +6,7 @@ package org.opensearch.sql.spark.ppl; import lombok.Getter; +import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.NamedExpression; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; import org.opensearch.sql.analysis.TypeEnvironment; @@ -14,6 +15,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Objects; +import java.util.Stack; /** The context used for Catalyst logical plan. 
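It holds the evolving Catalyst plan together with a stack of already-translated expressions that visitor steps push and pop.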
*/ public class CatalystPlanContext { @@ -24,7 +26,7 @@ public class CatalystPlanContext { @Getter private LogicalPlan plan; /** NamedExpression contextual parameters **/ - @Getter private final List namedParseExpressions; + @Getter private final Stack namedParseExpressions; @Getter private final FunctionProperties functionProperties; @@ -39,7 +41,7 @@ public CatalystPlanContext() { */ public CatalystPlanContext(TypeEnvironment environment) { this.environment = environment; - this.namedParseExpressions = new ArrayList<>(); + this.namedParseExpressions = new Stack<>(); this.functionProperties = new FunctionProperties(); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystQueryPlanVisitor.java b/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystQueryPlanVisitor.java index 1fa64e40d6..1d45810798 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystQueryPlanVisitor.java +++ b/spark/src/main/java/org/opensearch/sql/spark/ppl/CatalystQueryPlanVisitor.java @@ -13,6 +13,8 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; import org.apache.spark.sql.catalyst.analysis.UnresolvedTable; +import org.apache.spark.sql.catalyst.expressions.BinaryComparison; +import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.NamedExpression; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.expression.AggregateFunction; @@ -64,6 +66,7 @@ import java.util.stream.Collectors; import static java.util.List.of; +import static org.opensearch.sql.spark.ppl.DataTypeTransformer.translate; import static scala.Option.empty; import static scala.collection.JavaConverters.asScalaBuffer; @@ -120,6 +123,8 @@ public String visitTableFunction(TableFunction node, CatalystPlanContext context public String visitFilter(Filter node, CatalystPlanContext context) { String child = node.getChild().get(0).accept(this, context); String condition = visitExpression(node.getCondition(),context); + Expression innerCondition = context.getNamedParseExpressions().pop(); + context.plan(new org.apache.spark.sql.catalyst.plans.logical.Filter(innerCondition,context.getPlan())); return StringUtils.format("%s | where %s", child, condition); } @@ -280,6 +285,7 @@ public String analyze(UnresolvedExpression unresolved, CatalystPlanContext conte @Override public String visitLiteral(Literal node, CatalystPlanContext context) { + context.getNamedParseExpressions().add(new org.apache.spark.sql.catalyst.expressions.Literal(node.getValue(),translate(node.getType()))); return node.toString(); } @@ -334,9 +340,9 @@ public String visitFunction(Function node, CatalystPlanContext context) { @Override public String visitCompare(Compare node, CatalystPlanContext context) { - String left = analyze(node.getLeft(), context); String right = analyze(node.getRight(), context); + context.getNamedParseExpressions().add(ComparatorTransformer.comparator(node, context)); return StringUtils.format("%s %s %s", left, node.getOperator(), right); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/ppl/ComparatorTransformer.java b/spark/src/main/java/org/opensearch/sql/spark/ppl/ComparatorTransformer.java index 5d158b1fc8..6189e95fac 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/ppl/ComparatorTransformer.java +++ b/spark/src/main/java/org/opensearch/sql/spark/ppl/ComparatorTransformer.java @@ -1,21 +1,30 @@ package org.opensearch.sql.spark.ppl; import 
org.apache.spark.sql.catalyst.expressions.BinaryComparison; +import org.apache.spark.sql.catalyst.expressions.EqualTo; +import org.apache.spark.sql.catalyst.expressions.Expression; import org.opensearch.sql.ast.expression.Compare; import org.opensearch.sql.expression.function.BuiltinFunctionName; /** * Transform the PPL Logical comparator into catalyst comparator */ -public class ComparatorTransformer { +public interface ComparatorTransformer { /** - * expression builder + * Comparator expression builder: builds a Catalyst binary comparator from PPL's compare logical step * @return */ - public static BinaryComparison comparator(Compare expression) { + static BinaryComparison comparator(Compare expression, CatalystPlanContext context) { if (BuiltinFunctionName.of(expression.getOperator()).isEmpty()) throw new IllegalStateException("Unexpected value: " + BuiltinFunctionName.of(expression.getOperator())); - + + if (context.getNamedParseExpressions().isEmpty()) { + throw new IllegalStateException("Unexpected value: No operands found in expression"); + } + + Expression right = context.getNamedParseExpressions().pop(); + Expression left = context.getNamedParseExpressions().isEmpty() ? null : context.getNamedParseExpressions().pop(); + switch (BuiltinFunctionName.of(expression.getOperator()).get()) { case ABS: break; @@ -262,8 +271,7 @@ public static BinaryComparison comparator(Compare expression) { case NOT: break; case EQUAL: -// return new EqualTo() - break; + return new EqualTo(left,right); case NOTEQUAL: break; case LESS: diff --git a/spark/src/main/java/org/opensearch/sql/spark/ppl/DataTypeTransformer.java b/spark/src/main/java/org/opensearch/sql/spark/ppl/DataTypeTransformer.java new file mode 100644 index 0000000000..0542153cb3 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/ppl/DataTypeTransformer.java @@ -0,0 +1,26 @@ +package org.opensearch.sql.spark.ppl; + + +import org.apache.spark.sql.types.ByteType$; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType$; +import org.apache.spark.sql.types.IntegerType$; +import org.apache.spark.sql.types.StringType$; + +/** + * Translate PPL AST expression data types into Catalyst data types + */ +public interface DataTypeTransformer { + static DataType translate(org.opensearch.sql.ast.expression.DataType source) { + switch (source.getCoreType()) { + case TIME: + // note: Catalyst has no dedicated time type, so TIME is currently approximated with DateType + return DateType$.MODULE$; + case INTEGER: + return IntegerType$.MODULE$; + case BYTE: + return ByteType$.MODULE$; + default: + return StringType$.MODULE$; + } + } +} \ No newline at end of file diff --git a/spark/src/test/java/org/opensearch/sql/spark/ppl/PPLToCatalystLogicalPlanTranslatorTest.java b/spark/src/test/java/org/opensearch/sql/spark/ppl/PPLToCatalystLogicalPlanTranslatorTest.java index 13445cebcc..09781ff475 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/ppl/PPLToCatalystLogicalPlanTranslatorTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/ppl/PPLToCatalystLogicalPlanTranslatorTest.java @@ -124,7 +124,7 @@ void testSourceWithTableAndConditionPlan() { List filterField = Arrays.asList( UnresolvedAttribute$.MODULE$.apply(JavaConverters.asScalaBuffer(Collections.singletonList("a"))) ); - UnresolvedTable table = new UnresolvedTable(asScalaBuffer(of("table")).toSeq(), "source=table ", Option.empty()); + UnresolvedTable table = new UnresolvedTable(asScalaBuffer(of("t")).toSeq(), "source=t", Option.empty()); // Create a Filter node for the 
condition 'a = 1' EqualTo filterCondition = new EqualTo((Expression) filterField.get(0), Literal.create(1, IntegerType)); LogicalPlan filterPlan = new Filter(filterCondition, table); @@ -144,7 +144,7 @@ void testSourceWithTableAndConditionWithOneFieldPlan() { List projectList = Arrays.asList( UnresolvedAttribute$.MODULE$.apply(JavaConverters.asScalaBuffer(Collections.singletonList("a"))) ); - UnresolvedTable table = new UnresolvedTable(asScalaBuffer(of("table")).toSeq(), "source=table ", Option.empty()); + UnresolvedTable table = new UnresolvedTable(asScalaBuffer(of("t")).toSeq(), "source=t ", Option.empty()); // Create a Filter node for the condition 'a = 1' EqualTo filterCondition = new EqualTo((Expression) projectList.get(0), Literal.create(1, IntegerType)); LogicalPlan filterPlan = new Filter(filterCondition, table);
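A minimal end-to-end sketch of how the pieces in this series fit together, mirroring the test plumbing above (plan(...) stands for the PPLSyntaxParser/AstBuilder setup shown in the test class, and the expected plan shape follows testSourceWithTableAndConditionWithOneFieldPlan):

    // PPL text -> PPL AST -> Catalyst logical plan (sketch)
    CatalystPlanContext context = new CatalystPlanContext();
    Statement pplAst = plan("source=t a = 1 | fields a", false);
    new CatalystQueryPlanVisitor().visit(pplAst, context);
    // roughly: Project(['a], Filter(('a = 1), UnresolvedTable(t)))
    LogicalPlan catalystPlan = context.getPlan();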