From 355a489df14d87cdc6950ed7fb48b582259a0dce Mon Sep 17 00:00:00 2001 From: Tiewei Fang <43782773+BePPPower@users.noreply.github.com> Date: Thu, 16 May 2024 20:43:40 +0800 Subject: [PATCH] [Enhencement](trino-connector) The trino-connector catalog supports pushdown predicates to the connector (#34422) The Trino connector SPI provides interfaces for filter pushdown, which can be utilized to push predicates down to the connector layer. --- .../TrinoConnectorJniScanner.java | 31 +- .../TrinoConnectorPluginLoader.java | 6 + .../source/PaimonPredicateConverter.java | 2 +- .../TrinoConnectorPluginLoader.java | 6 + .../TrinoConnectorPredicateConverter.java | 336 ++++++++ .../source/TrinoConnectorScanNode.java | 124 ++- .../source/TrinoConnectorSource.java | 12 +- .../TrinoConnectorPredicateTest.java | 735 ++++++++++++++++++ 8 files changed, 1218 insertions(+), 34 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/datasource/trinoconnector/source/TrinoConnectorPredicateConverter.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/datasource/trinoconnector/TrinoConnectorPredicateTest.java diff --git a/fe/be-java-extensions/trino-connector-scanner/src/main/java/org/apache/doris/trinoconnector/TrinoConnectorJniScanner.java b/fe/be-java-extensions/trino-connector-scanner/src/main/java/org/apache/doris/trinoconnector/TrinoConnectorJniScanner.java index 3209213ee6931d..a9d857693a6ab5 100644 --- a/fe/be-java-extensions/trino-connector-scanner/src/main/java/org/apache/doris/trinoconnector/TrinoConnectorJniScanner.java +++ b/fe/be-java-extensions/trino-connector-scanner/src/main/java/org/apache/doris/trinoconnector/TrinoConnectorJniScanner.java @@ -31,6 +31,7 @@ import io.trino.Session; import io.trino.SystemSessionProperties; import io.trino.SystemSessionPropertiesProvider; +import io.trino.block.BlockJsonSerde; import io.trino.client.ClientCapabilities; import io.trino.connector.CatalogServiceProviderModule; import io.trino.execution.DynamicFilterConfig; @@ -40,8 +41,10 @@ import io.trino.execution.scheduler.NodeSchedulerConfig; import io.trino.memory.MemoryManagerConfig; import io.trino.memory.NodeMemoryConfig; +import io.trino.metadata.BlockEncodingManager; import io.trino.metadata.HandleJsonModule; import io.trino.metadata.HandleResolver; +import io.trino.metadata.InternalBlockEncodingSerde; import io.trino.metadata.SessionPropertyManager; import io.trino.plugin.base.TypeDeserializer; import io.trino.spi.Page; @@ -134,6 +137,14 @@ public TrinoConnectorJniScanner(int batchSize, Map params) { connectorColumnMetadataString = params.get("trino_connector_column_metadata"); connectorPredicateString = params.get("trino_connector_predicate"); connectorTrascationHandleString = params.get("trino_connector_trascation_handle"); + if (LOG.isDebugEnabled()) { + LOG.debug("TrinoConnectorJniScanner connectorSplitString = " + connectorSplitString + + " ; connectorTableHandleString = " + connectorTableHandleString + + " ; connectorColumnHandleString = " + connectorColumnHandleString + + " ; connectorColumnMetadataString = " + connectorColumnMetadataString + + " ; connectorPredicateString = " + connectorPredicateString + + " ; connectorTrascationHandleString = " + connectorTrascationHandleString); + } trinoConnectorOptionParams = params.entrySet().stream() @@ -225,7 +236,7 @@ private ConnectorPageSourceProvider getConnectorPageSourceProvider() { Objects.requireNonNull(connectorPageSourceProvider, String.format("Connector '%s' returned a null page source provider", catalogNameString)); } catch (UnsupportedOperationException ignored) { - LOG.warn("exception when getPageSourceProvider: " + ignored.getMessage()); + LOG.debug("exception when getPageSourceProvider: " + ignored.getMessage()); } try { @@ -238,15 +249,13 @@ private ConnectorPageSourceProvider getConnectorPageSourceProvider() { } connectorPageSourceProvider = new RecordPageSourceProvider(connectorRecordSetProvider); } catch (UnsupportedOperationException ignored) { - LOG.warn("exception when getRecordSetProvider: " + ignored.getMessage()); + LOG.debug("exception when getRecordSetProvider: " + ignored.getMessage()); } return connectorPageSourceProvider; } private ObjectMapperProvider generateObjectMapperProvider() { - TypeManager typeManager = new InternalTypeManager( - TrinoConnectorPluginLoader.getTrinoConnectorPluginManager().getTypeRegistry()); ObjectMapperProvider objectMapperProvider = new ObjectMapperProvider(); Set modules = new HashSet(); modules.add(HandleJsonModule.tableHandleModule(handleResolver)); @@ -259,8 +268,15 @@ private ObjectMapperProvider generateObjectMapperProvider() { // modules.add(HandleJsonModule.indexHandleModule(handleResolver)); // modules.add(HandleJsonModule.partitioningHandleModule(handleResolver)); objectMapperProvider.setModules(modules); - objectMapperProvider.setJsonDeserializers( - ImmutableMap.of(io.trino.spi.type.Type.class, new TypeDeserializer(typeManager))); + + // set json deserializers + TypeManager typeManager = new InternalTypeManager( + TrinoConnectorPluginLoader.getTrinoConnectorPluginManager().getTypeRegistry()); + InternalBlockEncodingSerde blockEncodingSerde = new InternalBlockEncodingSerde(new BlockEncodingManager(), + typeManager); + objectMapperProvider.setJsonDeserializers(ImmutableMap.of( + io.trino.spi.type.Type.class, new TypeDeserializer(typeManager), + Block.class, new BlockJsonSerde.Deserializer(blockEncodingSerde))); return objectMapperProvider; } @@ -320,9 +336,6 @@ private void parseRequiredTypes() { trinoTypeList.add(columnMetadataList.get(index).getType()); String hiveType = TrinoTypeToHiveTypeTranslator.fromTrinoTypeToHiveType(trinoTypeList.get(i)); columnTypes[i] = ColumnType.parseType(fields[i], hiveType); - - // LOG.info(String.format("Trino type: [%s], hive type: [%s], columnTypes: [%s].", - // trinoTypeList.get(i), hiveType, columnTypes[i])); } super.types = columnTypes; } diff --git a/fe/be-java-extensions/trino-connector-scanner/src/main/java/org/apache/doris/trinoconnector/TrinoConnectorPluginLoader.java b/fe/be-java-extensions/trino-connector-scanner/src/main/java/org/apache/doris/trinoconnector/TrinoConnectorPluginLoader.java index 67041895021282..472260aa35a9f0 100644 --- a/fe/be-java-extensions/trino-connector-scanner/src/main/java/org/apache/doris/trinoconnector/TrinoConnectorPluginLoader.java +++ b/fe/be-java-extensions/trino-connector-scanner/src/main/java/org/apache/doris/trinoconnector/TrinoConnectorPluginLoader.java @@ -36,11 +36,17 @@ import java.util.logging.Level; import java.util.logging.SimpleFormatter; +// Noninstancetiable utility class public class TrinoConnectorPluginLoader { private static final Logger LOG = LogManager.getLogger(TrinoConnectorPluginLoader.class); private static String pluginsDir = EnvUtils.getDorisHome() + "/connectors"; + // Suppress default constructor for noninstantiability + private TrinoConnectorPluginLoader() { + throw new AssertionError(); + } + private static class TrinoConnectorPluginLoad { private static FeaturesConfig featuresConfig = new FeaturesConfig(); private static TrinoConnectorPluginManager trinoConnectorPluginManager; diff --git a/fe/fe-core/src/main/java/org/apache/doris/datasource/paimon/source/PaimonPredicateConverter.java b/fe/fe-core/src/main/java/org/apache/doris/datasource/paimon/source/PaimonPredicateConverter.java index 9e65d9714713b8..848e419b2fbcb4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/datasource/paimon/source/PaimonPredicateConverter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/datasource/paimon/source/PaimonPredicateConverter.java @@ -60,7 +60,7 @@ public List convertToPaimonExpr(List conjuncts) { return list; } - public Predicate convertToPaimonExpr(Expr dorisExpr) { + private Predicate convertToPaimonExpr(Expr dorisExpr) { if (dorisExpr == null) { return null; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/datasource/trinoconnector/TrinoConnectorPluginLoader.java b/fe/fe-core/src/main/java/org/apache/doris/datasource/trinoconnector/TrinoConnectorPluginLoader.java index 1e08c9effcb00f..c846e2edf4297c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/datasource/trinoconnector/TrinoConnectorPluginLoader.java +++ b/fe/fe-core/src/main/java/org/apache/doris/datasource/trinoconnector/TrinoConnectorPluginLoader.java @@ -36,9 +36,15 @@ import java.util.logging.Level; import java.util.logging.SimpleFormatter; +// Noninstancetiable utility class public class TrinoConnectorPluginLoader { private static final Logger LOG = LogManager.getLogger(TrinoConnectorPluginLoader.class); + // Suppress default constructor for noninstantiability + private TrinoConnectorPluginLoader() { + throw new AssertionError(); + } + private static class TrinoConnectorPluginLoad { private static FeaturesConfig featuresConfig = new FeaturesConfig(); private static TypeOperators typeOperators = new TypeOperators(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/datasource/trinoconnector/source/TrinoConnectorPredicateConverter.java b/fe/fe-core/src/main/java/org/apache/doris/datasource/trinoconnector/source/TrinoConnectorPredicateConverter.java new file mode 100644 index 00000000000000..9916f5af8e84bd --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/datasource/trinoconnector/source/TrinoConnectorPredicateConverter.java @@ -0,0 +1,336 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.datasource.trinoconnector.source; + +import org.apache.doris.analysis.BinaryPredicate; +import org.apache.doris.analysis.CastExpr; +import org.apache.doris.analysis.CompoundPredicate; +import org.apache.doris.analysis.DateLiteral; +import org.apache.doris.analysis.DecimalLiteral; +import org.apache.doris.analysis.Expr; +import org.apache.doris.analysis.InPredicate; +import org.apache.doris.analysis.IsNullPredicate; +import org.apache.doris.analysis.LiteralExpr; +import org.apache.doris.analysis.NullLiteral; +import org.apache.doris.analysis.SlotRef; +import org.apache.doris.common.AnalysisException; +import org.apache.doris.common.util.TimeUtils; +import org.apache.doris.thrift.TExprOpcode; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import io.airlift.slice.Slices; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ColumnMetadata; +import io.trino.spi.predicate.Domain; +import io.trino.spi.predicate.Range; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.predicate.ValueSet; +import io.trino.spi.type.Int128; +import io.trino.spi.type.LongTimestamp; +import io.trino.spi.type.LongTimestampWithTimeZone; +import io.trino.spi.type.TimeZoneKey; +import io.trino.spi.type.Type; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.math.BigDecimal; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.TimeZone; + + +public class TrinoConnectorPredicateConverter { + private static final Logger LOG = LogManager.getLogger(TrinoConnectorPredicateConverter.class); + private static final String EPOCH_DATE = "1970-01-01"; + private static final String GMT = "GMT"; + private final Map trinoConnectorColumnHandleMap; + + private final Map trinoConnectorColumnMetadataMap; + + public TrinoConnectorPredicateConverter(Map columnHandleMap, + Map columnMetadataMap) { + this.trinoConnectorColumnHandleMap = columnHandleMap; + this.trinoConnectorColumnMetadataMap = columnMetadataMap; + } + + public TupleDomain convertExprToTrinoTupleDomain(Expr predicate) throws AnalysisException { + if (predicate instanceof CompoundPredicate) { + return compoundPredicateConverter((CompoundPredicate) predicate); + } else if (predicate instanceof InPredicate) { + return inPredicateConverter((InPredicate) predicate); + } else if (predicate instanceof BinaryPredicate) { + return binaryPredicateConverter((BinaryPredicate) predicate); + } else if (predicate instanceof IsNullPredicate) { + return isNullPredicateConverter((IsNullPredicate) predicate); + } else { + throw new AnalysisException("Do not support convert predicate: [" + predicate + "]."); + } + } + + private TupleDomain compoundPredicateConverter(CompoundPredicate compoundPredicate) + throws AnalysisException { + switch (compoundPredicate.getOp()) { + case AND: { + TupleDomain left = null; + TupleDomain right = null; + try { + left = convertExprToTrinoTupleDomain(compoundPredicate.getChild(0)); + } catch (AnalysisException e) { + LOG.warn("left predicate of compund predicate failed, exception: " + e.getMessage()); + } + try { + right = convertExprToTrinoTupleDomain(compoundPredicate.getChild(1)); + } catch (AnalysisException e) { + LOG.warn("right predicate of compound predicate failed, exception: " + e.getMessage()); + } + if (left != null && right != null) { + return left.intersect(right); + } else if (left != null) { + return left; + } else if (right != null) { + return right; + } + throw new AnalysisException("Can not convert both sides of compound predicate [" + + compoundPredicate.getOp() + "] to TupleDomain."); + } + case OR: { + TupleDomain left = convertExprToTrinoTupleDomain(compoundPredicate.getChild(0)); + TupleDomain right = convertExprToTrinoTupleDomain(compoundPredicate.getChild(1)); + return TupleDomain.columnWiseUnion(left, right); + } + case NOT: + default: + throw new AnalysisException("Do not support convert compound predicate [" + compoundPredicate.getOp() + + "] to TupleDomain."); + } + } + + private TupleDomain inPredicateConverter(InPredicate predicate) throws AnalysisException { + // Make sure the col slot is always first + SlotRef slotRef = convertExprToSlotRef(predicate.getChild(0)); + if (slotRef == null) { + throw new AnalysisException("slotRef is null in inPredicateConverter."); + } + String colName = slotRef.getColumnName(); + Type type = trinoConnectorColumnMetadataMap.get(colName).getType(); + List ranges = Lists.newArrayList(); + for (int i = 1; i < predicate.getChildren().size(); i++) { + LiteralExpr literalExpr = convertExprToLiteral(predicate.getChild(i)); + if (literalExpr == null) { + throw new AnalysisException("literalExpr of InPredicate's children is null in inPredicateConverter."); + } + ranges.add(Range.equal(type, convertLiteralToDomainValues(type.getClass(), literalExpr))); + } + + Domain domain = predicate.isNotIn() + ? Domain.create(ValueSet.all(type).subtract(ValueSet.ofRanges(ranges)), false) + : Domain.create(ValueSet.ofRanges(ranges), false); + TupleDomain tupleDomain = TupleDomain.withColumnDomains( + ImmutableMap.of(trinoConnectorColumnHandleMap.get(colName), domain)); + return tupleDomain; + } + + private TupleDomain binaryPredicateConverter(BinaryPredicate predicate) throws AnalysisException { + // Make sure the col slot is always first + SlotRef slotRef = convertExprToSlotRef(predicate.getChild(0)); + if (slotRef == null) { + throw new AnalysisException("slotRef is null in binaryPredicateConverter."); + } + LiteralExpr literalExpr = convertExprToLiteral(predicate.getChild(1)); + // literalExpr == null means predicate.getChild(1) is not a LiteralExpr or CastExpr + // such as 'where A.a < A.b',predicate.getChild(1) is SlotRef + if (literalExpr == null) { + throw new AnalysisException("literalExpr of BinaryPredicate's child is null in binaryPredicateConverter."); + } + + String colName = slotRef.getColumnName(); + Type type = trinoConnectorColumnMetadataMap.get(colName).getType(); + Domain domain = null; + TExprOpcode opcode = predicate.getOpcode(); + switch (opcode) { + case EQ: + domain = Domain.create(ValueSet.ofRanges(Range.equal(type, + convertLiteralToDomainValues(type.getClass(), literalExpr))), false); + break; + case EQ_FOR_NULL: { + if (literalExpr instanceof NullLiteral) { + domain = Domain.onlyNull(type); + } else { + domain = Domain.create(ValueSet.ofRanges(Range.equal(type, + convertLiteralToDomainValues(type.getClass(), literalExpr))), false); + } + break; + } + case NE: + domain = Domain.create(ValueSet.all(type).subtract(ValueSet.ofRanges(Range.equal(type, + convertLiteralToDomainValues(type.getClass(), literalExpr)))), false); + break; + case LT: + domain = Domain.create(ValueSet.ofRanges(Range.lessThan(type, + convertLiteralToDomainValues(type.getClass(), literalExpr))), false); + break; + case LE: + domain = Domain.create(ValueSet.ofRanges(Range.lessThanOrEqual(type, + convertLiteralToDomainValues(type.getClass(), literalExpr))), false); + break; + case GT: + domain = Domain.create(ValueSet.ofRanges(Range.greaterThan(type, + convertLiteralToDomainValues(type.getClass(), literalExpr))), false); + break; + case GE: + domain = Domain.create(ValueSet.ofRanges(Range.greaterThanOrEqual(type, + convertLiteralToDomainValues(type.getClass(), literalExpr))), false); + break; + case INVALID_OPCODE: + default: + throw new AnalysisException("Do not support opcode [" + opcode + "] in binaryPredicateConverter."); + } + return TupleDomain.withColumnDomains(ImmutableMap.of(trinoConnectorColumnHandleMap.get(colName), domain)); + } + + private TupleDomain isNullPredicateConverter(IsNullPredicate predicate) throws AnalysisException { + Objects.requireNonNull(predicate.getChild(0), "The first child of IsNullPredicate is null."); + SlotRef slotRef = convertExprToSlotRef(predicate.getChild(0)); + if (slotRef == null) { + throw new AnalysisException("slotRef is null in IsNullPredicate."); + } + String colName = slotRef.getColumnName(); + Type type = trinoConnectorColumnMetadataMap.get(colName).getType(); + if (predicate.isNotNull()) { + return TupleDomain.withColumnDomains( + ImmutableMap.of(trinoConnectorColumnHandleMap.get(colName), Domain.notNull(type))); + } + return TupleDomain.withColumnDomains( + ImmutableMap.of(trinoConnectorColumnHandleMap.get(colName), Domain.onlyNull(type))); + } + + /* Since different Trino types have different data formats stored in their Range, + we need to convert the data format stored in Doris's LiteralExpr to the corresponding Java data type + which can be recognized by the Trino Type Range. + The correspondence between different Trino types and the Java data types stored in their Range is as follows: + + Trino Type Java Type + + BooleanType boolean + TinyintType long + SmallintType long + IntegerType long + BigintType long + RealType long + ShortDecimalType long + LongDecimalType io.trino.spi.type.Int128 + CharType io.airlift.slice.Slice + VarbinaryType io.airlift.slice.Slice + VarcharType io.airlift.slice.Slice + DateType long + DoubleType double + TimeType long + ShortTimestampType long + LongTimestampType io.trino.spi.type.LongTimestamp + ShortTimestampWithTimeZoneType long + LongTimestampWithTimeZoneType io.trino.spi.type.LongTimestampWithTimeZone + ArrayType io.trino.spi.block.Block + MapType io.trino.spi.block.SqlMap + RowType io.trino.spi.block.SqlRow*/ + private Object convertLiteralToDomainValues(Class type, LiteralExpr literalExpr) + throws AnalysisException { + switch (type.getSimpleName()) { + case "BooleanType": + return literalExpr.getRealValue(); + case "TinyintType": + case "SmallintType": + case "IntegerType": + case "BigintType": + return literalExpr.getLongValue(); + case "RealType": + return (long) Float.floatToIntBits((float) literalExpr.getDoubleValue()); + case "DoubleType": + return literalExpr.getDoubleValue(); + case "ShortDecimalType": { + BigDecimal value = (BigDecimal) literalExpr.getRealValue(); + BigDecimal tmpValue = new BigDecimal(Math.pow(10, DecimalLiteral.getBigDecimalScale(value))); + value = value.multiply(tmpValue); + return value.longValue(); + } + case "LongDecimalType": { + BigDecimal value = (BigDecimal) literalExpr.getRealValue(); + BigDecimal tmpValue = new BigDecimal(Math.pow(10, DecimalLiteral.getBigDecimalScale(value))); + value = value.multiply(tmpValue); + return Int128.valueOf(value.toBigIntegerExact()); + } + case "CharType": + case "VarbinaryType": + case "VarcharType": + return Slices.utf8Slice((String) literalExpr.getRealValue()); + case "DateType": + return ((DateLiteral) literalExpr).daynr() - new DateLiteral(EPOCH_DATE).daynr(); + case "ShortTimestampType": { + DateLiteral dateLiteral = (DateLiteral) literalExpr; + return dateLiteral.unixTimestamp(TimeZone.getTimeZone(GMT)) * 1000 + + dateLiteral.getMicrosecond(); + } + case "LongTimestampType": { + DateLiteral dateLiteral = (DateLiteral) literalExpr; + long epochMicros = dateLiteral.unixTimestamp(TimeZone.getTimeZone(GMT)) * 1000 + + dateLiteral.getMicrosecond(); + return new LongTimestamp(epochMicros, 0); + } + case "LongTimestampWithTimeZoneType": { + DateLiteral dateLiteral = (DateLiteral) literalExpr; + long epochMillis = dateLiteral.unixTimestamp(TimeUtils.getTimeZone()); + int picosOfMilli = (int) dateLiteral.getMicrosecond() * 1000000; + TimeZoneKey timeZoneKey = TimeZoneKey.getTimeZoneKey(TimeUtils.getTimeZone().toZoneId().toString()); + return LongTimestampWithTimeZone.fromEpochMillisAndFraction(epochMillis, picosOfMilli, timeZoneKey); + } + case "ShortTimestampWithTimeZoneType": + case "TimeType": + case "ArrayType": + case "MapType": + case "RowType": + default: + return new AnalysisException("Do not support convert trino type [" + type.getSimpleName() + + "] to domain values."); + } + } + + private SlotRef convertExprToSlotRef(Expr expr) { + SlotRef slotRef = null; + if (expr instanceof SlotRef) { + slotRef = (SlotRef) expr; + } else if (expr instanceof CastExpr) { + if (expr.getChild(0) instanceof SlotRef) { + slotRef = (SlotRef) expr.getChild(0); + } + } + return slotRef; + } + + private LiteralExpr convertExprToLiteral(Expr expr) { + LiteralExpr literalExpr = null; + if (expr instanceof LiteralExpr) { + literalExpr = (LiteralExpr) expr; + } else if (expr instanceof CastExpr) { + if (expr.getChild(0) instanceof LiteralExpr) { + literalExpr = (LiteralExpr) expr.getChild(0); + } + } + return literalExpr; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/datasource/trinoconnector/source/TrinoConnectorScanNode.java b/fe/fe-core/src/main/java/org/apache/doris/datasource/trinoconnector/source/TrinoConnectorScanNode.java index a143f08a870926..9c61d54614a070 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/datasource/trinoconnector/source/TrinoConnectorScanNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/datasource/trinoconnector/source/TrinoConnectorScanNode.java @@ -52,9 +52,12 @@ import io.airlift.json.ObjectMapperProvider; import io.trino.Session; import io.trino.SystemSessionProperties; +import io.trino.block.BlockJsonSerde; +import io.trino.metadata.BlockEncodingManager; import io.trino.metadata.HandleJsonModule; import io.trino.metadata.HandleResolver; -import io.trino.plugin.base.TypeDeserializer; +import io.trino.metadata.InternalBlockEncodingSerde; +import io.trino.spi.block.Block; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ColumnMetadata; import io.trino.spi.connector.Connector; @@ -65,29 +68,37 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTransactionHandle; import io.trino.spi.connector.Constraint; +import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.DynamicFilter; +import io.trino.spi.connector.LimitApplicationResult; +import io.trino.spi.predicate.TupleDomain; import io.trino.spi.transaction.IsolationLevel; import io.trino.spi.type.TypeManager; import io.trino.split.BufferingSplitSource; import io.trino.split.ConnectorAwareSplitSource; import io.trino.split.SplitSource; import io.trino.type.InternalTypeManager; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.stream.Collectors; public class TrinoConnectorScanNode extends FileQueryScanNode { + private static final Logger LOG = LogManager.getLogger(TrinoConnectorScanNode.class); private static final int minScheduleSplitBatchSize = 10; private TrinoConnectorSource source = null; private ObjectMapperProvider objectMapperProvider; - // private static List predicates; + private ConnectorMetadata connectorMetadata; + private Constraint constraint; public TrinoConnectorScanNode(PlanNodeId id, TupleDescriptor desc, boolean needCheckColumnPriv) { super(id, desc, "TRINO_CONNECTOR_SCAN_NODE", StatisticalType.TRINO_CONNECTOR_SCAN_NODE, needCheckColumnPriv); @@ -102,11 +113,31 @@ protected void doInitialize() throws UserException { table.getName())); } - source = new TrinoConnectorSource(desc, table); - computeColumnsFilter(); initBackendPolicy(); + source = new TrinoConnectorSource(desc, table); initSchemaParams(); + convertPredicate(); + } + + protected void convertPredicate() throws UserException { + if (conjuncts.isEmpty()) { + constraint = Constraint.alwaysTrue(); + } + TupleDomain summary = TupleDomain.all(); + TrinoConnectorPredicateConverter trinoConnectorPredicateConverter = new TrinoConnectorPredicateConverter( + source.getTargetTable().getColumnHandleMap(), + source.getTargetTable().getColumnMetadataMap()); + try { + for (int i = 0; i < conjuncts.size(); ++i) { + summary = summary.intersect( + trinoConnectorPredicateConverter.convertExprToTrinoTupleDomain(conjuncts.get(i))); + } + } catch (AnalysisException e) { + LOG.warn("Can not convert Expr to trino tuple domain, exception: {}", e.getMessage()); + summary = TupleDomain.all(); + } + constraint = new Constraint(summary); } @Override @@ -117,29 +148,77 @@ public List getSplits() throws UserException { IsolationLevel.READ_UNCOMMITTED, true, true); source.setConnectorTransactionHandle(connectorTransactionHandle); ConnectorSession connectorSession = source.getTrinoSession().toConnectorSession(source.getCatalogHandle()); - ConnectorMetadata connectorMetadata = connector.getMetadata(connectorSession, connectorTransactionHandle); + connectorMetadata = connector.getMetadata(connectorSession, connectorTransactionHandle); // 2. Begin query connectorMetadata.beginQuery(connectorSession); + applyPushDown(connectorSession); // 3. get splitSource - SplitSource splitSource = getTrinoSplitSource(connector, source.getTrinoSession(), connectorTransactionHandle, - source.getTrinoConnectorExtTableHandle(), - DynamicFilter.EMPTY, - Constraint.alwaysTrue()); - // 4. get trino.Splits and convert it to doris.Splits List splits = Lists.newArrayList(); - while (!splitSource.isFinished()) { - for (io.trino.metadata.Split split : getNextSplitBatch(splitSource)) { - splits.add(new TrinoConnectorSplit(split.getConnectorSplit(), source.getConnectorName())); + try (SplitSource splitSource = getTrinoSplitSource(connector, source.getTrinoSession(), + connectorTransactionHandle, source.getTrinoConnectorTableHandle(), DynamicFilter.EMPTY)) { + // 4. get trino.Splits and convert it to doris.Splits + while (!splitSource.isFinished()) { + for (io.trino.metadata.Split split : getNextSplitBatch(splitSource)) { + splits.add(new TrinoConnectorSplit(split.getConnectorSplit(), source.getConnectorName())); + } } } return splits; } + private void applyPushDown(ConnectorSession connectorSession) { + // push down predicate/filter + Optional> filterResult + = connectorMetadata.applyFilter(connectorSession, source.getTrinoConnectorTableHandle(), constraint); + if (filterResult.isPresent()) { + source.setTrinoConnectorTableHandle(filterResult.get().getHandle()); + } + + // push down limit + if (hasLimit()) { + long limit = getLimit(); + Optional> limitResult + = connectorMetadata.applyLimit(connectorSession, source.getTrinoConnectorTableHandle(), limit); + if (limitResult.isPresent()) { + source.setTrinoConnectorTableHandle(limitResult.get().getHandle()); + } + } + + if (LOG.isDebugEnabled()) { + LOG.debug("The TrinoConnectorTableHandle is " + source.getTrinoConnectorTableHandle() + + " after pushing down."); + } + + // TODO(ftw): push down projection + // Map columnHandleMap = source.getTargetTable().getColumnHandleMap(); + // Map assignments = Maps.newHashMap(); + // if (source.getTargetTable().getName().equals("customer")) { + // assignments.put("c_custkey", columnHandleMap.get("c_custkey")); + // assignments.put("c_mktsegment", columnHandleMap.get("c_mktsegment")); + // } else if (source.getTargetTable().getName().equals("orders")) { + // assignments.put("o_orderkey", columnHandleMap.get("o_orderkey")); + // assignments.put("o_custkey", columnHandleMap.get("o_custkey")); + // assignments.put("o_orderdate", columnHandleMap.get("o_orderdate")); + // assignments.put("o_shippriority", columnHandleMap.get("o_shippriority")); + // } else if (source.getTargetTable().getName().equals("lineitem")) { + // assignments.put("l_orderkey", columnHandleMap.get("l_orderkey")); + // assignments.put("l_extendedprice", columnHandleMap.get("l_extendedprice")); + // assignments.put("l_discount", columnHandleMap.get("l_discount")); + // assignments.put("l_shipdate", columnHandleMap.get("l_shipdate")); + // } + // Optional> projectionResult + // = connectorMetadata.applyProjection(connectorSession, source.getTrinoConnectorTableHandle(), + // Lists.newArrayList(), assignments); + // if (projectionResult.isPresent()) { + // source.setTrinoConnectorTableHandle(projectionResult.get().getHandle()); + // } + } + private SplitSource getTrinoSplitSource(Connector connector, Session session, ConnectorTransactionHandle connectorTransactionHandle, ConnectorTableHandle table, - DynamicFilter dynamicFilter, Constraint constraint) { + DynamicFilter dynamicFilter) { ConnectorSplitManager splitManager = connector.getSplitManager(); if (!SystemSessionProperties.isAllowPushdownIntoConnectors(session)) { @@ -147,7 +226,7 @@ private SplitSource getTrinoSplitSource(Connector connector, Session session, } ConnectorSession connectorSession = session.toConnectorSession(source.getCatalogHandle()); - // TODO(ftw): here can not use table.getTransactionHandle + // Constraint is not used by Hive/BigQuery Connector ConnectorSplitSource connectorSplitSource = splitManager.getSplits(connectorTransactionHandle, connectorSession, table, dynamicFilter, constraint); @@ -183,8 +262,8 @@ public void setTrinoConnectorParams(TFileRangeDesc rangeDesc, TrinoConnectorSpli fileDesc.setDbName(source.getTargetTable().getDbName()); fileDesc.setTrinoConnectorOptions(source.getCatalog().getTrinoConnectorProperties()); fileDesc.setTableName(source.getTargetTable().getName()); - fileDesc.setTrinoConnectorTableHandle( - encodeObjectToString(source.getTargetTable().getConnectorTableHandle(), objectMapperProvider)); + fileDesc.setTrinoConnectorTableHandle(encodeObjectToString( + source.getTrinoConnectorTableHandle(), objectMapperProvider)); Map columnHandleMap = source.getTargetTable().getColumnHandleMap(); Map columnMetadataMap = source.getTargetTable().getColumnMetadataMap(); @@ -220,7 +299,6 @@ public void setTrinoConnectorParams(TFileRangeDesc rangeDesc, TrinoConnectorSpli private ObjectMapperProvider createObjectMapperProvider() { // mock ObjectMapperProvider - TypeManager typeManager = new InternalTypeManager(TrinoConnectorPluginLoader.getTypeRegistry()); ObjectMapperProvider objectMapperProvider = new ObjectMapperProvider(); Set modules = new HashSet(); HandleResolver handleResolver = TrinoConnectorPluginLoader.getHandleResolver(); @@ -233,9 +311,15 @@ private ObjectMapperProvider createObjectMapperProvider() { // modules.add(HandleJsonModule.tableExecuteHandleModule(handleResolver)); // modules.add(HandleJsonModule.indexHandleModule(handleResolver)); // modules.add(HandleJsonModule.partitioningHandleModule(handleResolver)); + // modules.add(HandleJsonModule.tableFunctionHandleModule(handleResolver)); objectMapperProvider.setModules(modules); - objectMapperProvider.setJsonDeserializers( - ImmutableMap.of(io.trino.spi.type.Type.class, new TypeDeserializer(typeManager))); + + // set json deserializers + TypeManager typeManager = new InternalTypeManager(TrinoConnectorPluginLoader.getTypeRegistry()); + InternalBlockEncodingSerde blockEncodingSerde = new InternalBlockEncodingSerde(new BlockEncodingManager(), + typeManager); + objectMapperProvider.setJsonSerializers(ImmutableMap.of(Block.class, + new BlockJsonSerde.Serializer(blockEncodingSerde))); return objectMapperProvider; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/datasource/trinoconnector/source/TrinoConnectorSource.java b/fe/fe-core/src/main/java/org/apache/doris/datasource/trinoconnector/source/TrinoConnectorSource.java index d9aea14bbaf086..85e36517baedb1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/datasource/trinoconnector/source/TrinoConnectorSource.java +++ b/fe/fe-core/src/main/java/org/apache/doris/datasource/trinoconnector/source/TrinoConnectorSource.java @@ -39,14 +39,14 @@ public class TrinoConnectorSource { private final Connector connector; private final ConnectorName connectorName; private ConnectorTransactionHandle connectorTransactionHandle; - private final ConnectorTableHandle trinoConnectorExtTableHandle; + private ConnectorTableHandle trinoConnectorTableHandle; public TrinoConnectorSource(TupleDescriptor desc, TrinoConnectorExternalTable table) { this.desc = desc; this.trinoConnectorExtTable = table; this.trinoConnectorExternalCatalog = (TrinoConnectorExternalCatalog) table.getCatalog(); this.catalogHandle = trinoConnectorExternalCatalog.getTrinoCatalogHandle(); - this.trinoConnectorExtTableHandle = table.getConnectorTableHandle(); + this.trinoConnectorTableHandle = table.getConnectorTableHandle(); this.trinoSession = trinoConnectorExternalCatalog.getTrinoSession(); this.connector = ((TrinoConnectorExternalCatalog) table.getCatalog()).getConnector(); this.connectorName = ((TrinoConnectorExternalCatalog) table.getCatalog()).getConnectorName(); @@ -56,8 +56,8 @@ public TupleDescriptor getDesc() { return desc; } - public ConnectorTableHandle getTrinoConnectorExtTableHandle() { - return trinoConnectorExtTableHandle; + public ConnectorTableHandle getTrinoConnectorTableHandle() { + return trinoConnectorTableHandle; } public TrinoConnectorExternalTable getTargetTable() { @@ -92,6 +92,10 @@ public void setConnectorTransactionHandle(ConnectorTransactionHandle connectorTr this.connectorTransactionHandle = connectorTransactionHandle; } + public void setTrinoConnectorTableHandle(ConnectorTableHandle trinoConnectorExtTableHandle) { + this.trinoConnectorTableHandle = trinoConnectorExtTableHandle; + } + public ConnectorTransactionHandle getConnectorTransactionHandle() { return connectorTransactionHandle; } diff --git a/fe/fe-core/src/test/java/org/apache/doris/datasource/trinoconnector/TrinoConnectorPredicateTest.java b/fe/fe-core/src/test/java/org/apache/doris/datasource/trinoconnector/TrinoConnectorPredicateTest.java new file mode 100644 index 00000000000000..1ec5654de985ab --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/datasource/trinoconnector/TrinoConnectorPredicateTest.java @@ -0,0 +1,735 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.datasource.trinoconnector; + +import org.apache.doris.analysis.BinaryPredicate; +import org.apache.doris.analysis.BinaryPredicate.Operator; +import org.apache.doris.analysis.BoolLiteral; +import org.apache.doris.analysis.CompoundPredicate; +import org.apache.doris.analysis.DateLiteral; +import org.apache.doris.analysis.DecimalLiteral; +import org.apache.doris.analysis.Expr; +import org.apache.doris.analysis.FloatLiteral; +import org.apache.doris.analysis.InPredicate; +import org.apache.doris.analysis.IntLiteral; +import org.apache.doris.analysis.LiteralExpr; +import org.apache.doris.analysis.NullLiteral; +import org.apache.doris.analysis.SlotRef; +import org.apache.doris.analysis.StringLiteral; +import org.apache.doris.analysis.TableName; +import org.apache.doris.catalog.Type; +import org.apache.doris.common.AnalysisException; +import org.apache.doris.datasource.trinoconnector.source.TrinoConnectorPredicateConverter; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import io.airlift.slice.Slices; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ColumnMetadata; +import io.trino.spi.predicate.Domain; +import io.trino.spi.predicate.Range; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.predicate.ValueSet; +import io.trino.spi.type.BigintType; +import io.trino.spi.type.BooleanType; +import io.trino.spi.type.CharType; +import io.trino.spi.type.DateType; +import io.trino.spi.type.DecimalType; +import io.trino.spi.type.DoubleType; +import io.trino.spi.type.Int128; +import io.trino.spi.type.IntegerType; +import io.trino.spi.type.LongTimestamp; +import io.trino.spi.type.LongTimestampWithTimeZone; +import io.trino.spi.type.RealType; +import io.trino.spi.type.SmallintType; +import io.trino.spi.type.TimeZoneKey; +import io.trino.spi.type.TimestampType; +import io.trino.spi.type.TimestampWithTimeZoneType; +import io.trino.spi.type.TinyintType; +import io.trino.spi.type.VarbinaryType; +import io.trino.spi.type.VarcharType; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.util.List; +import java.util.Objects; + +public class TrinoConnectorPredicateTest { + + private static final ImmutableMap trinoConnectorColumnHandleMap = + new ImmutableMap.Builder() + .put("c_bool", new MockColumnHandle("c_bool")) + .put("c_tinyint", new MockColumnHandle("c_tinyint")) + .put("c_smallint", new MockColumnHandle("c_smallint")) + .put("c_int", new MockColumnHandle("c_int")) + .put("c_bigint", new MockColumnHandle("c_bigint")) + .put("c_real", new MockColumnHandle("c_real")) + .put("c_short_decimal", new MockColumnHandle("c_short_decimal")) + .put("c_long_decimal", new MockColumnHandle("c_long_decimal")) + .put("c_char", new MockColumnHandle("c_char")) + .put("c_varchar", new MockColumnHandle("c_varchar")) + .put("c_varbinary", new MockColumnHandle("c_varbinary")) + .put("c_date", new MockColumnHandle("c_date")) + .put("c_double", new MockColumnHandle("c_double")) + .put("c_short_timestamp", new MockColumnHandle("c_short_timestamp")) + // .put("c_short_timestamp_timezone", new MockColumnHandle("c_short_timestamp_timezone")) + .put("c_long_timestamp", new MockColumnHandle("c_long_timestamp")) + .put("c_long_timestamp_timezone", new MockColumnHandle("c_long_timestamp_timezone")) + .build(); + + private static final ImmutableMap trinoConnectorColumnMetadataMap = + new ImmutableMap.Builder() + .put("c_bool", new ColumnMetadata("c_bool", BooleanType.BOOLEAN)) + .put("c_tinyint", new ColumnMetadata("c_tinyint", TinyintType.TINYINT)) + .put("c_smallint", new ColumnMetadata("c_smallint", SmallintType.SMALLINT)) + .put("c_int", new ColumnMetadata("c_int", IntegerType.INTEGER)) + .put("c_bigint", new ColumnMetadata("c_bigint", BigintType.BIGINT)) + .put("c_real", new ColumnMetadata("c_real", RealType.REAL)) + .put("c_short_decimal", new ColumnMetadata("c_short_decimal", + DecimalType.createDecimalType(9, 2))) + .put("c_long_decimal", new ColumnMetadata("c_long_decimal", + DecimalType.createDecimalType(38, 15))) + .put("c_char", new ColumnMetadata("c_char", CharType.createCharType(128))) + .put("c_varchar", new ColumnMetadata("c_varchar", + VarcharType.createVarcharType(128))) + .put("c_varbinary", new ColumnMetadata("c_varbinary", VarbinaryType.VARBINARY)) + .put("c_date", new ColumnMetadata("c_date", DateType.DATE)) + .put("c_double", new ColumnMetadata("c_double", DoubleType.DOUBLE)) + .put("c_short_timestamp", new ColumnMetadata("c_short_timestamp", + TimestampType.TIMESTAMP_MICROS)) + // .put("c_short_timestamp_timezone", new ColumnMetadata("c_short_timestamp_timezone", + // TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS)) + .put("c_long_timestamp", new ColumnMetadata("c_long_timestamp", + TimestampType.TIMESTAMP_PICOS)) + .put("c_long_timestamp_timezone", new ColumnMetadata("c_long_timestamp_timezone", + TimestampWithTimeZoneType.TIMESTAMP_TZ_PICOS)) + .build(); + + private static TrinoConnectorPredicateConverter trinoConnectorPredicateConverter; + + @BeforeClass + public static void before() throws AnalysisException { + trinoConnectorPredicateConverter = new TrinoConnectorPredicateConverter( + trinoConnectorColumnHandleMap, + trinoConnectorColumnMetadataMap); + } + + @Test + public void testBinaryEqPredicate() throws AnalysisException { + // construct slotRefs and literalLists + List slotRefs = mockSlotRefs(); + List literalList = mockLiteralExpr(); + + // expect results + List> expectTupleDomain = Lists.newArrayList(); + ImmutableList expectRanges = new ImmutableList.Builder() + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_bool").getType(), true)) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_tinyint").getType(), 1L)) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_smallint").getType(), 1L)) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_int").getType(), 1L)) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_bigint").getType(), 1L)) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_real").getType(), + Long.valueOf(Float.floatToIntBits(1.23f)))) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_double").getType(), 3.1415926456)) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_short_decimal").getType(), 12345623L)) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_long_decimal").getType(), + Int128.valueOf(new BigInteger("12345678901234567890123123")))) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_char").getType(), + Slices.utf8Slice("trino connector char test"))) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_varchar").getType(), + Slices.utf8Slice("trino connector varchar test"))) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_varbinary").getType(), + Slices.utf8Slice("trino connector varbinary test"))) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_date").getType(), -1L)) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_short_timestamp").getType(), + 1000001L)) + // .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_short_timestamp_timezone").getType(), + // 0L)) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_long_timestamp").getType(), + new LongTimestamp(1000001L, 0))) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_long_timestamp_timezone").getType(), + LongTimestampWithTimeZone.fromEpochMillisAndFraction(1000L, 1000000, + TimeZoneKey.getTimeZoneKey("Asia/Shanghai")))) + .build(); + for (int i = 0; i < slotRefs.size(); i++) { + final String colName = slotRefs.get(i).getColumnName(); + Domain domain = Domain.create(ValueSet.ofRanges(Lists.newArrayList(expectRanges.get(i))), false); + TupleDomain tupleDomain = TupleDomain.withColumnDomains( + ImmutableMap.of(trinoConnectorColumnHandleMap.get(colName), domain)); + expectTupleDomain.add(tupleDomain); + } + + // test results, construct equal binary predicate + List> testTupleDomain = Lists.newArrayList(); + for (int i = 0; i < slotRefs.size(); i++) { + BinaryPredicate expr = new BinaryPredicate(BinaryPredicate.Operator.EQ, slotRefs.get(i), + literalList.get(i)); + TupleDomain tupleDomain = trinoConnectorPredicateConverter.convertExprToTrinoTupleDomain( + expr); + testTupleDomain.add(tupleDomain); + } + + // verify if `testTupleDomain` is equal to `expectTupleDomain`. + for (int i = 0; i < expectTupleDomain.size(); i++) { + Assert.assertTrue(expectTupleDomain.get(i).contains(testTupleDomain.get(i))); + } + } + + @Test + public void testBinaryEqualForNullPredicate() throws AnalysisException { + // construct slotRefs and literalLists + List slotRefs = mockSlotRefs(); + List literalList = mockLiteralExpr(); + + // expect results + List> expectTupleDomain = Lists.newArrayList(); + ImmutableList expectRanges = new ImmutableList.Builder() + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_bool").getType(), true)) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_tinyint").getType(), 1L)) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_smallint").getType(), 1L)) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_int").getType(), 1L)) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_bigint").getType(), 1L)) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_real").getType(), + Long.valueOf(Float.floatToIntBits(1.23f)))) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_double").getType(), 3.1415926456)) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_short_decimal").getType(), 12345623L)) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_long_decimal").getType(), + Int128.valueOf(new BigInteger("12345678901234567890123123")))) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_char").getType(), + Slices.utf8Slice("trino connector char test"))) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_varchar").getType(), + Slices.utf8Slice("trino connector varchar test"))) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_varbinary").getType(), + Slices.utf8Slice("trino connector varbinary test"))) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_date").getType(), -1L)) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_short_timestamp").getType(), + 1000001L)) + // .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_short_timestamp_timezone").getType(), + // 0L)) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_long_timestamp").getType(), + new LongTimestamp(1000001L, 0))) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_long_timestamp_timezone").getType(), + LongTimestampWithTimeZone.fromEpochMillisAndFraction(1000L, 1000000, + TimeZoneKey.getTimeZoneKey("Asia/Shanghai")))) + .build(); + for (int i = 0; i < slotRefs.size(); i++) { + final String colName = slotRefs.get(i).getColumnName(); + Domain domain = Domain.create(ValueSet.ofRanges(Lists.newArrayList(expectRanges.get(i))), false); + TupleDomain tupleDomain = TupleDomain.withColumnDomains( + ImmutableMap.of(trinoConnectorColumnHandleMap.get(colName), domain)); + expectTupleDomain.add(tupleDomain); + } + + // test results, construct equal binary predicate + List> testTupleDomain = Lists.newArrayList(); + for (int i = 0; i < slotRefs.size(); i++) { + BinaryPredicate expr = new BinaryPredicate(Operator.EQ_FOR_NULL, slotRefs.get(i), + literalList.get(i)); + TupleDomain tupleDomain = trinoConnectorPredicateConverter.convertExprToTrinoTupleDomain( + expr); + testTupleDomain.add(tupleDomain); + } + + // verify if `testTupleDomain` is equal to `expectTupleDomain`. + for (int i = 0; i < expectTupleDomain.size(); i++) { + Assert.assertTrue(expectTupleDomain.get(i).contains(testTupleDomain.get(i))); + } + + // test <=> + SlotRef intSlot = new SlotRef(new TableName("test_table"), "c_int"); + NullLiteral nullLiteral = NullLiteral.create(Type.INT); + BinaryPredicate expr = new BinaryPredicate(Operator.EQ_FOR_NULL, intSlot, nullLiteral); + TupleDomain testNullTupleDomain = trinoConnectorPredicateConverter.convertExprToTrinoTupleDomain( + expr); + TupleDomain expectNullTupleDomain = TupleDomain.withColumnDomains( + ImmutableMap.of(trinoConnectorColumnHandleMap.get("c_int"), Domain.onlyNull(IntegerType.INTEGER))); + Assert.assertTrue(expectNullTupleDomain.contains(testNullTupleDomain)); + } + + @Test + public void testBinaryLessThanPredicate() throws AnalysisException { + // construct slotRefs and literalLists + List slotRefs = mockSlotRefs(); + List literalList = mockLiteralExpr(); + + // expect results + List> expectTupleDomain = Lists.newArrayList(); + ImmutableList expectRanges = new ImmutableList.Builder() + .add(Range.lessThan(trinoConnectorColumnMetadataMap.get("c_bool").getType(), true)) + .add(Range.lessThan(trinoConnectorColumnMetadataMap.get("c_tinyint").getType(), 1L)) + .add(Range.lessThan(trinoConnectorColumnMetadataMap.get("c_smallint").getType(), 1L)) + .add(Range.lessThan(trinoConnectorColumnMetadataMap.get("c_int").getType(), 1L)) + .add(Range.lessThan(trinoConnectorColumnMetadataMap.get("c_bigint").getType(), 1L)) + .add(Range.lessThan(trinoConnectorColumnMetadataMap.get("c_real").getType(), + Long.valueOf(Float.floatToIntBits(1.23f)))) + .add(Range.lessThan(trinoConnectorColumnMetadataMap.get("c_double").getType(), 3.1415926456)) + .add(Range.lessThan(trinoConnectorColumnMetadataMap.get("c_short_decimal").getType(), 12345623L)) + .add(Range.lessThan(trinoConnectorColumnMetadataMap.get("c_long_decimal").getType(), + Int128.valueOf(new BigInteger("12345678901234567890123123")))) + .add(Range.lessThan(trinoConnectorColumnMetadataMap.get("c_char").getType(), + Slices.utf8Slice("trino connector char test"))) + .add(Range.lessThan(trinoConnectorColumnMetadataMap.get("c_varchar").getType(), + Slices.utf8Slice("trino connector varchar test"))) + .add(Range.lessThan(trinoConnectorColumnMetadataMap.get("c_varbinary").getType(), + Slices.utf8Slice("trino connector varbinary test"))) + .add(Range.lessThan(trinoConnectorColumnMetadataMap.get("c_date").getType(), -1L)) + .add(Range.lessThan(trinoConnectorColumnMetadataMap.get("c_short_timestamp").getType(), + 1000001L)) + // .add(Range.lessThan(trinoConnectorColumnMetadataMap.get("c_short_timestamp_timezone").getType(), + // 0L)) + .add(Range.lessThan(trinoConnectorColumnMetadataMap.get("c_long_timestamp").getType(), + new LongTimestamp(1000001L, 0))) + .add(Range.lessThan(trinoConnectorColumnMetadataMap.get("c_long_timestamp_timezone").getType(), + LongTimestampWithTimeZone.fromEpochMillisAndFraction(1000L, 1000000, + TimeZoneKey.getTimeZoneKey("Asia/Shanghai")))) + .build(); + for (int i = 0; i < slotRefs.size(); i++) { + final String colName = slotRefs.get(i).getColumnName(); + Domain domain = Domain.create(ValueSet.ofRanges(Lists.newArrayList(expectRanges.get(i))), false); + TupleDomain tupleDomain = TupleDomain.withColumnDomains( + ImmutableMap.of(trinoConnectorColumnHandleMap.get(colName), domain)); + expectTupleDomain.add(tupleDomain); + } + + // test results, construct lessThan binary predicate + List> testTupleDomain = Lists.newArrayList(); + for (int i = 0; i < slotRefs.size(); i++) { + BinaryPredicate expr = new BinaryPredicate(Operator.LT, slotRefs.get(i), + literalList.get(i)); + TupleDomain tupleDomain = trinoConnectorPredicateConverter.convertExprToTrinoTupleDomain( + expr); + testTupleDomain.add(tupleDomain); + } + + // verify if `testTupleDomain` is equal to `expectTupleDomain`. + for (int i = 0; i < expectTupleDomain.size(); i++) { + Assert.assertTrue(expectTupleDomain.get(i).contains(testTupleDomain.get(i))); + } + } + + @Test + public void testBinaryLessEqualPredicate() throws AnalysisException { + // construct slotRefs and literalLists + List slotRefs = mockSlotRefs(); + List literalList = mockLiteralExpr(); + + // expect results + List> expectTupleDomain = Lists.newArrayList(); + ImmutableList expectRanges = new ImmutableList.Builder() + .add(Range.lessThanOrEqual(trinoConnectorColumnMetadataMap.get("c_bool").getType(), true)) + .add(Range.lessThanOrEqual(trinoConnectorColumnMetadataMap.get("c_tinyint").getType(), 1L)) + .add(Range.lessThanOrEqual(trinoConnectorColumnMetadataMap.get("c_smallint").getType(), 1L)) + .add(Range.lessThanOrEqual(trinoConnectorColumnMetadataMap.get("c_int").getType(), 1L)) + .add(Range.lessThanOrEqual(trinoConnectorColumnMetadataMap.get("c_bigint").getType(), 1L)) + .add(Range.lessThanOrEqual(trinoConnectorColumnMetadataMap.get("c_real").getType(), + Long.valueOf(Float.floatToIntBits(1.23f)))) + .add(Range.lessThanOrEqual(trinoConnectorColumnMetadataMap.get("c_double").getType(), 3.1415926456)) + .add(Range.lessThanOrEqual(trinoConnectorColumnMetadataMap.get("c_short_decimal").getType(), 12345623L)) + .add(Range.lessThanOrEqual(trinoConnectorColumnMetadataMap.get("c_long_decimal").getType(), + Int128.valueOf(new BigInteger("12345678901234567890123123")))) + .add(Range.lessThanOrEqual(trinoConnectorColumnMetadataMap.get("c_char").getType(), + Slices.utf8Slice("trino connector char test"))) + .add(Range.lessThanOrEqual(trinoConnectorColumnMetadataMap.get("c_varchar").getType(), + Slices.utf8Slice("trino connector varchar test"))) + .add(Range.lessThanOrEqual(trinoConnectorColumnMetadataMap.get("c_varbinary").getType(), + Slices.utf8Slice("trino connector varbinary test"))) + .add(Range.lessThanOrEqual(trinoConnectorColumnMetadataMap.get("c_date").getType(), -1L)) + .add(Range.lessThanOrEqual(trinoConnectorColumnMetadataMap.get("c_short_timestamp").getType(), + 1000001L)) + // .add(Range.lessThanOrEqual(trinoConnectorColumnMetadataMap.get("c_short_timestamp_timezone").getType(), + // 0L)) + .add(Range.lessThanOrEqual(trinoConnectorColumnMetadataMap.get("c_long_timestamp").getType(), + new LongTimestamp(1000001L, 0))) + .add(Range.lessThanOrEqual(trinoConnectorColumnMetadataMap.get("c_long_timestamp_timezone").getType(), + LongTimestampWithTimeZone.fromEpochMillisAndFraction(1000L, 1000000, + TimeZoneKey.getTimeZoneKey("Asia/Shanghai")))) + .build(); + for (int i = 0; i < slotRefs.size(); i++) { + final String colName = slotRefs.get(i).getColumnName(); + Domain domain = Domain.create(ValueSet.ofRanges(Lists.newArrayList(expectRanges.get(i))), false); + TupleDomain tupleDomain = TupleDomain.withColumnDomains( + ImmutableMap.of(trinoConnectorColumnHandleMap.get(colName), domain)); + expectTupleDomain.add(tupleDomain); + } + + // test results, construct lessThanOrEqual binary predicate + List> testTupleDomain = Lists.newArrayList(); + for (int i = 0; i < slotRefs.size(); i++) { + BinaryPredicate expr = new BinaryPredicate(Operator.LE, slotRefs.get(i), + literalList.get(i)); + TupleDomain tupleDomain = trinoConnectorPredicateConverter.convertExprToTrinoTupleDomain( + expr); + testTupleDomain.add(tupleDomain); + } + + // verify if `testTupleDomain` is equal to `expectTupleDomain`. + for (int i = 0; i < expectTupleDomain.size(); i++) { + Assert.assertTrue(expectTupleDomain.get(i).contains(testTupleDomain.get(i))); + } + } + + @Test + public void testBinaryGreatThanPredicate() throws AnalysisException { + // construct slotRefs and literalLists + List slotRefs = mockSlotRefs(); + List literalList = mockLiteralExpr(); + + // expect results + List> expectTupleDomain = Lists.newArrayList(); + ImmutableList expectRanges = new ImmutableList.Builder() + .add(Range.greaterThan(trinoConnectorColumnMetadataMap.get("c_bool").getType(), true)) + .add(Range.greaterThan(trinoConnectorColumnMetadataMap.get("c_tinyint").getType(), 1L)) + .add(Range.greaterThan(trinoConnectorColumnMetadataMap.get("c_smallint").getType(), 1L)) + .add(Range.greaterThan(trinoConnectorColumnMetadataMap.get("c_int").getType(), 1L)) + .add(Range.greaterThan(trinoConnectorColumnMetadataMap.get("c_bigint").getType(), 1L)) + .add(Range.greaterThan(trinoConnectorColumnMetadataMap.get("c_real").getType(), + Long.valueOf(Float.floatToIntBits(1.23f)))) + .add(Range.greaterThan(trinoConnectorColumnMetadataMap.get("c_double").getType(), 3.1415926456)) + .add(Range.greaterThan(trinoConnectorColumnMetadataMap.get("c_short_decimal").getType(), 12345623L)) + .add(Range.greaterThan(trinoConnectorColumnMetadataMap.get("c_long_decimal").getType(), + Int128.valueOf(new BigInteger("12345678901234567890123123")))) + .add(Range.greaterThan(trinoConnectorColumnMetadataMap.get("c_char").getType(), + Slices.utf8Slice("trino connector char test"))) + .add(Range.greaterThan(trinoConnectorColumnMetadataMap.get("c_varchar").getType(), + Slices.utf8Slice("trino connector varchar test"))) + .add(Range.greaterThan(trinoConnectorColumnMetadataMap.get("c_varbinary").getType(), + Slices.utf8Slice("trino connector varbinary test"))) + .add(Range.greaterThan(trinoConnectorColumnMetadataMap.get("c_date").getType(), -1L)) + .add(Range.greaterThan(trinoConnectorColumnMetadataMap.get("c_short_timestamp").getType(), + 1000001L)) + // .add(Range.greaterThan(trinoConnectorColumnMetadataMap.get("c_short_timestamp_timezone").getType(), + // 0L)) + .add(Range.greaterThan(trinoConnectorColumnMetadataMap.get("c_long_timestamp").getType(), + new LongTimestamp(1000001L, 0))) + .add(Range.greaterThan(trinoConnectorColumnMetadataMap.get("c_long_timestamp_timezone").getType(), + LongTimestampWithTimeZone.fromEpochMillisAndFraction(1000L, 1000000, + TimeZoneKey.getTimeZoneKey("Asia/Shanghai")))) + .build(); + for (int i = 0; i < slotRefs.size(); i++) { + final String colName = slotRefs.get(i).getColumnName(); + Domain domain = Domain.create(ValueSet.ofRanges(Lists.newArrayList(expectRanges.get(i))), false); + TupleDomain tupleDomain = TupleDomain.withColumnDomains( + ImmutableMap.of(trinoConnectorColumnHandleMap.get(colName), domain)); + expectTupleDomain.add(tupleDomain); + } + + // test results, construct greaterThan binary predicate + List> testTupleDomain = Lists.newArrayList(); + for (int i = 0; i < slotRefs.size(); i++) { + BinaryPredicate expr = new BinaryPredicate(Operator.GT, slotRefs.get(i), + literalList.get(i)); + TupleDomain tupleDomain = trinoConnectorPredicateConverter.convertExprToTrinoTupleDomain( + expr); + testTupleDomain.add(tupleDomain); + } + + // verify if `testTupleDomain` is equal to `expectTupleDomain`. + for (int i = 0; i < expectTupleDomain.size(); i++) { + Assert.assertTrue(expectTupleDomain.get(i).contains(testTupleDomain.get(i))); + } + } + + @Test + public void testBinaryGreaterEqualPredicate() throws AnalysisException { + // construct slotRefs and literalLists + List slotRefs = mockSlotRefs(); + List literalList = mockLiteralExpr(); + + // expect results + List> expectTupleDomain = Lists.newArrayList(); + ImmutableList expectRanges = new ImmutableList.Builder() + .add(Range.greaterThanOrEqual(trinoConnectorColumnMetadataMap.get("c_bool").getType(), true)) + .add(Range.greaterThanOrEqual(trinoConnectorColumnMetadataMap.get("c_tinyint").getType(), 1L)) + .add(Range.greaterThanOrEqual(trinoConnectorColumnMetadataMap.get("c_smallint").getType(), 1L)) + .add(Range.greaterThanOrEqual(trinoConnectorColumnMetadataMap.get("c_int").getType(), 1L)) + .add(Range.greaterThanOrEqual(trinoConnectorColumnMetadataMap.get("c_bigint").getType(), 1L)) + .add(Range.greaterThanOrEqual(trinoConnectorColumnMetadataMap.get("c_real").getType(), + Long.valueOf(Float.floatToIntBits(1.23f)))) + .add(Range.greaterThanOrEqual(trinoConnectorColumnMetadataMap.get("c_double").getType(), 3.1415926456)) + .add(Range.greaterThanOrEqual(trinoConnectorColumnMetadataMap.get("c_short_decimal").getType(), 12345623L)) + .add(Range.greaterThanOrEqual(trinoConnectorColumnMetadataMap.get("c_long_decimal").getType(), + Int128.valueOf(new BigInteger("12345678901234567890123123")))) + .add(Range.greaterThanOrEqual(trinoConnectorColumnMetadataMap.get("c_char").getType(), + Slices.utf8Slice("trino connector char test"))) + .add(Range.greaterThanOrEqual(trinoConnectorColumnMetadataMap.get("c_varchar").getType(), + Slices.utf8Slice("trino connector varchar test"))) + .add(Range.greaterThanOrEqual(trinoConnectorColumnMetadataMap.get("c_varbinary").getType(), + Slices.utf8Slice("trino connector varbinary test"))) + .add(Range.greaterThanOrEqual(trinoConnectorColumnMetadataMap.get("c_date").getType(), -1L)) + .add(Range.greaterThanOrEqual(trinoConnectorColumnMetadataMap.get("c_short_timestamp").getType(), + 1000001L)) + // .add(Range.greaterThanOrEqual(trinoConnectorColumnMetadataMap.get("c_short_timestamp_timezone").getType(), + // 0L)) + .add(Range.greaterThanOrEqual(trinoConnectorColumnMetadataMap.get("c_long_timestamp").getType(), + new LongTimestamp(1000001L, 0))) + .add(Range.greaterThanOrEqual(trinoConnectorColumnMetadataMap.get("c_long_timestamp_timezone").getType(), + LongTimestampWithTimeZone.fromEpochMillisAndFraction(1000L, 1000000, + TimeZoneKey.getTimeZoneKey("Asia/Shanghai")))) + .build(); + for (int i = 0; i < slotRefs.size(); i++) { + final String colName = slotRefs.get(i).getColumnName(); + Domain domain = Domain.create(ValueSet.ofRanges(Lists.newArrayList(expectRanges.get(i))), false); + TupleDomain tupleDomain = TupleDomain.withColumnDomains( + ImmutableMap.of(trinoConnectorColumnHandleMap.get(colName), domain)); + expectTupleDomain.add(tupleDomain); + } + + // test results, construct greaterThanOrEqual binary predicate + List> testTupleDomain = Lists.newArrayList(); + for (int i = 0; i < slotRefs.size(); i++) { + BinaryPredicate expr = new BinaryPredicate(Operator.GE, slotRefs.get(i), + literalList.get(i)); + TupleDomain tupleDomain = trinoConnectorPredicateConverter.convertExprToTrinoTupleDomain( + expr); + testTupleDomain.add(tupleDomain); + } + + // verify if `testTupleDomain` is equal to `expectTupleDomain`. + for (int i = 0; i < expectTupleDomain.size(); i++) { + Assert.assertTrue(expectTupleDomain.get(i).contains(testTupleDomain.get(i))); + } + } + + @Test + public void testInPredicate() throws AnalysisException { + // construct slotRefs and literalLists + List slotRefs = mockSlotRefs(); + List literalList = mockLiteralExpr(); + + // expect results + List> expectInTupleDomain = Lists.newArrayList(); + List> expectNotInTupleDomain = Lists.newArrayList(); + ImmutableList expectRanges = new ImmutableList.Builder() + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_bool").getType(), true)) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_tinyint").getType(), 1L)) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_smallint").getType(), 1L)) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_int").getType(), 1L)) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_bigint").getType(), 1L)) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_real").getType(), + Long.valueOf(Float.floatToIntBits(1.23f)))) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_double").getType(), 3.1415926456)) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_short_decimal").getType(), 12345623L)) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_long_decimal").getType(), + Int128.valueOf(new BigInteger("12345678901234567890123123")))) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_char").getType(), + Slices.utf8Slice("trino connector char test"))) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_varchar").getType(), + Slices.utf8Slice("trino connector varchar test"))) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_varbinary").getType(), + Slices.utf8Slice("trino connector varbinary test"))) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_date").getType(), -1L)) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_short_timestamp").getType(), + 1000001L)) + // .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_short_timestamp_timezone").getType(), + // 0L)) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_long_timestamp").getType(), + new LongTimestamp(1000001L, 0))) + .add(Range.equal(trinoConnectorColumnMetadataMap.get("c_long_timestamp_timezone").getType(), + LongTimestampWithTimeZone.fromEpochMillisAndFraction(1000L, 1000000, + TimeZoneKey.getTimeZoneKey("Asia/Shanghai")))) + .build(); + + for (int i = 0; i < slotRefs.size(); i++) { + final String colName = slotRefs.get(i).getColumnName(); + Domain inDomain = Domain.create( + ValueSet.ofRanges(Lists.newArrayList(expectRanges.get(i))), false); + Domain notInDomain = Domain.create(ValueSet.all(trinoConnectorColumnMetadataMap.get(colName).getType()) + .subtract(ValueSet.ofRanges(expectRanges.get(i))), false); + TupleDomain inTupleDomain = TupleDomain.withColumnDomains( + ImmutableMap.of(trinoConnectorColumnHandleMap.get(colName), inDomain)); + TupleDomain notInTupleDomain = TupleDomain.withColumnDomains( + ImmutableMap.of(trinoConnectorColumnHandleMap.get(colName), notInDomain)); + expectInTupleDomain.add(inTupleDomain); + expectNotInTupleDomain.add(notInTupleDomain); + } + + // test results, construct equal binary predicate + List> testTupleDomain = Lists.newArrayList(); + for (int i = 0; i < slotRefs.size(); i++) { + InPredicate expr = new InPredicate(slotRefs.get(i), literalList.get(i), false); + TupleDomain tupleDomain = trinoConnectorPredicateConverter.convertExprToTrinoTupleDomain( + expr); + testTupleDomain.add(tupleDomain); + } + // verify if `testTupleDomain` is equal to `expectTupleDomain`. + for (int i = 0; i < expectInTupleDomain.size(); i++) { + Assert.assertTrue(expectInTupleDomain.get(i).contains(testTupleDomain.get(i))); + } + + testTupleDomain.clear(); + for (int i = 0; i < slotRefs.size(); i++) { + InPredicate expr = new InPredicate(slotRefs.get(i), literalList.get(i), true); + TupleDomain tupleDomain = trinoConnectorPredicateConverter.convertExprToTrinoTupleDomain( + expr); + testTupleDomain.add(tupleDomain); + } + // verify if `testTupleDomain` is equal to `expectTupleDomain`. + for (int i = 0; i < expectNotInTupleDomain.size(); i++) { + Assert.assertTrue(expectNotInTupleDomain.get(i).contains(testTupleDomain.get(i))); + } + } + + @Test + public void testCompoundPredicate() throws AnalysisException { + // construct slotRefs and literalLists + List slotRefs = mockSlotRefs(); + List literalList = mockLiteralExpr(); + + // valid expr + List validExprs = Lists.newArrayList(); + for (int i = 0; i < slotRefs.size(); i++) { + BinaryPredicate expr = new BinaryPredicate(BinaryPredicate.Operator.EQ, slotRefs.get(i), + literalList.get(i)); + validExprs.add(expr); + } + + // invalid expr + BinaryPredicate invalidExpr = new BinaryPredicate(BinaryPredicate.Operator.EQ, + literalList.get(0), literalList.get(0)); + + // AND + // valid AND valid + for (int i = 0; i < validExprs.size(); i++) { + for (int j = 0; j < validExprs.size(); j++) { + CompoundPredicate andPredicate = new CompoundPredicate(CompoundPredicate.Operator.AND, + validExprs.get(i), validExprs.get(j)); + trinoConnectorPredicateConverter.convertExprToTrinoTupleDomain(andPredicate); + } + } + + // valid AND invalid + CompoundPredicate andPredicate = new CompoundPredicate(CompoundPredicate.Operator.AND, + validExprs.get(0), invalidExpr); + trinoConnectorPredicateConverter.convertExprToTrinoTupleDomain(andPredicate); + + // invalid AND valid + andPredicate = new CompoundPredicate(CompoundPredicate.Operator.AND, invalidExpr, validExprs.get(0)); + trinoConnectorPredicateConverter.convertExprToTrinoTupleDomain(andPredicate); + + // invalid AND invalid + andPredicate = new CompoundPredicate(CompoundPredicate.Operator.AND, invalidExpr, invalidExpr); + try { + trinoConnectorPredicateConverter.convertExprToTrinoTupleDomain(andPredicate); + } catch (AnalysisException e) { + Assert.assertTrue(e.getMessage().contains("Can not convert both sides of compound predicate")); + } + + // OR + // valid OR valid + for (int i = 0; i < validExprs.size(); i++) { + for (int j = 0; j < validExprs.size(); j++) { + CompoundPredicate orPredicate = new CompoundPredicate(CompoundPredicate.Operator.OR, + validExprs.get(i), validExprs.get(j)); + trinoConnectorPredicateConverter.convertExprToTrinoTupleDomain(orPredicate); + } + } + + // // valid OR valid + try { + CompoundPredicate orPredicate = new CompoundPredicate(CompoundPredicate.Operator.AND, + validExprs.get(0), invalidExpr); + trinoConnectorPredicateConverter.convertExprToTrinoTupleDomain(orPredicate); + } catch (AnalysisException e) { + Assert.assertTrue(e.getMessage().contains("slotRef is null in binaryPredicateConverter")); + } + } + + private List mockSlotRefs() { + return new ImmutableList.Builder() + .add(new SlotRef(new TableName("test_table"), "c_bool")) + + .add(new SlotRef(new TableName("test_table"), "c_tinyint")) + .add(new SlotRef(new TableName("test_table"), "c_smallint")) + .add(new SlotRef(new TableName("test_table"), "c_int")) + .add(new SlotRef(new TableName("test_table"), "c_bigint")) + + .add(new SlotRef(new TableName("test_table"), "c_real")) + .add(new SlotRef(new TableName("test_table"), "c_double")) + + .add(new SlotRef(new TableName("test_table"), "c_short_decimal")) + .add(new SlotRef(new TableName("test_table"), "c_long_decimal")) + + .add(new SlotRef(new TableName("test_table"), "c_char")) + .add(new SlotRef(new TableName("test_table"), "c_varchar")) + .add(new SlotRef(new TableName("test_table"), "c_varbinary")) + + .add(new SlotRef(new TableName("test_table"), "c_date")) + .add(new SlotRef(new TableName("test_table"), "c_short_timestamp")) + // .add(new SlotRef(new TableName("test_table"), "c_short_timestamp_timezone")) + .add(new SlotRef(new TableName("test_table"), "c_long_timestamp")) + .add(new SlotRef(new TableName("test_table"), "c_long_timestamp_timezone")) + .build(); + } + + private List mockLiteralExpr() throws AnalysisException { + return new ImmutableList.Builder() + // boolean + .add(new BoolLiteral(true)) + // Integer + .add(new IntLiteral(1, Type.TINYINT)) + .add(new IntLiteral(1, Type.SMALLINT)) + .add(new IntLiteral(1, Type.INT)) + .add(new IntLiteral(1, Type.BIGINT)) + + .add(new FloatLiteral(1.23, Type.FLOAT)) // Real type + .add(new FloatLiteral(3.1415926456, Type.DOUBLE)) + + .add(new DecimalLiteral(new BigDecimal("123456.23"))) + .add(new DecimalLiteral(new BigDecimal("12345678901234567890123.123"))) + + .add(new StringLiteral("trino connector char test")) + .add(new StringLiteral("trino connector varchar test")) + .add(new StringLiteral("trino connector varbinary test")) + + .add(new DateLiteral("1969-12-31", Type.DATEV2)) + .add(new DateLiteral("1970-01-01 00:00:01.000001", Type.DATETIMEV2)) + // .add(new DateLiteral("1970-01-01 00:00:00.000000", Type.DATETIMEV2)) + .add(new DateLiteral("1970-01-01 00:00:01.000001", Type.DATETIMEV2)) + .add(new DateLiteral("1970-01-01 08:00:01.000001", Type.DATETIMEV2)) + .build(); + } + + private static class MockColumnHandle implements ColumnHandle { + private String colName; + + MockColumnHandle(String colName) { + this.colName = colName; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + MockColumnHandle that = (MockColumnHandle) o; + return colName.equals(that.colName); + } + + @Override + public int hashCode() { + return Objects.hash(colName); + } + } +}