From 7bb5957580cf7a0c26426458681ac9ce03f35c47 Mon Sep 17 00:00:00 2001 From: avv Date: Wed, 27 Nov 2024 10:41:30 +0500 Subject: [PATCH 1/4] ADH-5332 - implemented pushdown expression rewriters - implemented aggregation function rewriters - reformat code - fixed jdbc splitting bug --- plugin/trino-adb/pom.xml | 10 + .../io/trino/plugin/adb/AdbPluginConfig.java | 17 +- .../plugin/adb/connector/AdbClientModule.java | 8 +- .../adb/connector/AdbPushdownConfig.java | 175 +++++++++++++ .../AdbPushdownSessionProperties.java | 156 +++++++++++ .../adb/connector/AdbSessionProperties.java | 26 +- .../plugin/adb/connector/AdbSqlClient.java | 245 ++++++++++++++---- .../AdbBaseImplementStddevVariance.java | 80 ++++++ .../aggregation/AdbImplementAvgBigint.java | 26 ++ .../aggregation/AdbImplementMinMax.java | 70 +++++ .../aggregation/AdbImplementStddevPop.java | 23 ++ .../aggregation/AdbImplementStddevSamp.java | 23 ++ .../aggregation/AdbImplementVariancePop.java | 23 ++ .../aggregation/AdbImplementVarianceSamp.java | 23 ++ .../connector/datatype/VarcharDataType.java | 3 +- .../datatype/mapper/DataTypeMapperImpl.java | 120 ++++++--- .../connector/decode/csv/CsvRowDecoder.java | 15 +- .../connector/encode/AbstractRowEncoder.java | 67 +++-- .../connector/encode/csv/CsvFormatConfig.java | 14 +- .../connector/encode/csv/CsvRowEncoder.java | 2 +- .../expression/AdbRewriteBooleanConstant.java | 58 +++++ .../connector/expression/AdbRewriteCast.java | 90 +++++++ .../expression/AdbRewriteCharConstant.java | 59 +++++ .../AdbRewriteDatetimeArithmetics.java | 109 ++++++++ .../AdbRewriteInexactNumericConstant.java | 61 +++++ .../expression/AdbRewriteNullConstant.java | 78 ++++++ .../connector/metadata/AdbMetadataDao.java | 3 +- .../metadata/impl/AdbMetadataDaoImpl.java | 12 +- .../AbstractExternalTableQueryFactory.java | 2 +- .../protocol/gpfdist/GpfdistModule.java | 30 ++- .../load/process/GpfdistPageSinkProvider.java | 9 +- .../ExternalTableFormatConfigFactoryImpl.java | 2 +- .../metadata/GpfdistLocationFactoryImpl.java | 3 +- .../gpfdist/metadata/GpfdistMetadata.java | 5 +- .../gpfdist/server/GpfdistResource.java | 27 +- .../request/GpfdistWritableRequest.java | 3 +- .../process/GpfdistCsvDataProcessor.java | 6 +- .../unload/process/GpfdistRecordCursor.java | 3 +- .../process/GpfdistRecordSetProvider.java | 8 +- ...eateWritableExternalTableQueryFactory.java | 2 +- .../CollationAwareQueryBuilder.java | 16 +- .../{ => table}/AdbColumnMapping.java | 2 +- .../connector/{ => table}/AdbJdbcSplit.java | 2 +- .../connector/table/AdbTableProperties.java | 24 +- .../table/SplitSourceManagerImpl.java | 12 +- .../connector/table/StatisticsManager.java | 3 +- .../table/StatisticsManagerImpl.java | 21 +- 47 files changed, 1569 insertions(+), 207 deletions(-) create mode 100644 plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbPushdownConfig.java create mode 100644 plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbPushdownSessionProperties.java create mode 100644 plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbBaseImplementStddevVariance.java create mode 100644 plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementAvgBigint.java create mode 100644 plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementMinMax.java create mode 100644 plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementStddevPop.java create mode 100644 plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementStddevSamp.java create mode 100644 plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementVariancePop.java create mode 100644 plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementVarianceSamp.java create mode 100644 plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteBooleanConstant.java create mode 100644 plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteCast.java create mode 100644 plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteCharConstant.java create mode 100644 plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteDatetimeArithmetics.java create mode 100644 plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteInexactNumericConstant.java create mode 100644 plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteNullConstant.java rename plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/{ => query}/CollationAwareQueryBuilder.java (82%) rename plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/{ => table}/AdbColumnMapping.java (94%) rename plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/{ => table}/AdbJdbcSplit.java (98%) diff --git a/plugin/trino-adb/pom.xml b/plugin/trino-adb/pom.xml index 506fcc07c5815..c91adb248c2a3 100644 --- a/plugin/trino-adb/pom.xml +++ b/plugin/trino-adb/pom.xml @@ -135,6 +135,10 @@ io.opentelemetry opentelemetry-context + + io.trino + trino-matching + io.trino trino-plugin-toolkit @@ -150,6 +154,12 @@ + + io.trino + trino-matching + compile + + io.trino trino-plugin-toolkit diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/AdbPluginConfig.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/AdbPluginConfig.java index 2dc11487d2efd..d2baf53f53a85 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/AdbPluginConfig.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/AdbPluginConfig.java @@ -15,6 +15,7 @@ import io.airlift.configuration.Config; import io.airlift.configuration.ConfigDescription; +import io.airlift.configuration.LegacyConfig; import io.airlift.units.DataSize; import io.airlift.units.Duration; import io.airlift.units.MinDataSize; @@ -32,6 +33,7 @@ public class AdbPluginConfig private DataSize readBufferSize = DataSize.of(64L, DataSize.Unit.MEGABYTE); private final TransferDataProtocol dataProtocol = TransferDataProtocol.GPFDIST; private Duration gpfdistRetryTimeout; + private boolean enableStringPushdownWithCollate; public TransferDataProtocol getDataProtocol() { @@ -51,6 +53,19 @@ public AdbPluginConfig setArrayMapping(AdbPluginConfig.ArrayMapping arrayMapping return this; } + public boolean isEnableStringPushdownWithCollate() + { + return this.enableStringPushdownWithCollate; + } + + @Config("adb.enable-string-pushdown-with-collate") + @LegacyConfig("adb.experimental.enable-string-pushdown-with-collate") + public AdbPluginConfig setEnableStringPushdownWithCollate(boolean enableStringPushdownWithCollate) + { + this.enableStringPushdownWithCollate = enableStringPushdownWithCollate; + return this; + } + public boolean isIncludeSystemTables() { return this.includeSystemTables; @@ -112,7 +127,7 @@ public Duration getGpfdistRetryTimeout() return this.gpfdistRetryTimeout; } - @Config("adb.gpfdist.retry-timeout") + @Config("adb.connector.gpfdist.retry-timeout") @ConfigDescription("Value of adb gpfdist_retry_timeout property. Defaults to null (use adb defaults)") public AdbPluginConfig setGpfdistRetryTimeout(Duration gpfdistRetryTimeout) { diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbClientModule.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbClientModule.java index 6b50552ea9195..c6b44f6abaaf5 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbClientModule.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbClientModule.java @@ -27,6 +27,7 @@ import io.trino.plugin.adb.connector.metadata.impl.AdbMetadataDaoImpl; import io.trino.plugin.adb.connector.protocol.TransferDataProtocol; import io.trino.plugin.adb.connector.protocol.gpfdist.GpfdistModule; +import io.trino.plugin.adb.connector.query.CollationAwareQueryBuilder; import io.trino.plugin.adb.connector.table.AdbCreateTableStorageConfig; import io.trino.plugin.adb.connector.table.AdbTableProperties; import io.trino.plugin.adb.connector.table.SplitSourceManager; @@ -59,12 +60,15 @@ protected void setup(Binder binder) ConfigBinder.configBinder(binder).bindConfig(AdbCreateTableStorageConfig.class); ConfigBinder.configBinder(binder).bindConfig(JdbcStatisticsConfig.class); JdbcModule.bindSessionPropertiesProvider(binder, AdbSessionProperties.class); + JdbcModule.bindSessionPropertiesProvider(binder, AdbPushdownSessionProperties.class); JdbcModule.bindTablePropertiesProvider(binder, AdbTableProperties.class); - OptionalBinder.newOptionalBinder(binder, QueryBuilder.class).setBinding().to(CollationAwareQueryBuilder.class).in(Scopes.SINGLETON); + OptionalBinder.newOptionalBinder(binder, QueryBuilder.class).setBinding().to(CollationAwareQueryBuilder.class) + .in(Scopes.SINGLETON); this.install(new DecimalModule()); this.install(new JdbcJoinPushdownSupportModule()); this.install(new RemoteQueryCancellationModule()); - Multibinder.newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(Query.class).in(Scopes.SINGLETON); + Multibinder.newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(Query.class) + .in(Scopes.SINGLETON); AdbPluginConfig pluginConfig = this.buildConfigObject(AdbPluginConfig.class); if (pluginConfig.getDataProtocol() == TransferDataProtocol.GPFDIST) { this.install(new GpfdistModule()); diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbPushdownConfig.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbPushdownConfig.java new file mode 100644 index 0000000000000..38c038a827538 --- /dev/null +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbPushdownConfig.java @@ -0,0 +1,175 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.adb.connector; + +import io.airlift.configuration.Config; +import io.airlift.configuration.ConfigDescription; + +public class AdbPushdownConfig +{ + private boolean pushdownLiterals = true; + private boolean pushdownDecimalArithmetics = true; + private boolean pushdownDoubleArithmetics = true; + private boolean pushdownDatetimeArithmetics = true; + private boolean pushdownDatetimeComparison = true; + private boolean pushdownFunctionCast = true; + private boolean pushdownFunctionDatePart = true; + private boolean pushdownFunctionLike = true; + private boolean pushdownFunctionSubstring = true; + private boolean pushdownFunctionUpper = true; + private boolean pushdownFunctionLower = true; + + public boolean isPushdownLiterals() + { + return this.pushdownLiterals; + } + + @Config("adb.pushdown.literals") + @ConfigDescription("Whether to pushdown BOOLEAN, CHAR, REAL and DOUBLE literals, as well as literals with NULL values") + public AdbPushdownConfig setPushdownLiterals(boolean pushdownLiterals) + { + this.pushdownLiterals = pushdownLiterals; + return this; + } + + public boolean isPushdownDecimalArithmetics() + { + return this.pushdownDecimalArithmetics; + } + + @Config("adb.pushdown.decimal-arithmetics") + @ConfigDescription("Whether to pushdown arithmetical operations on DECIMAL data type") + public AdbPushdownConfig setPushdownDecimalArithmetics(boolean pushdownDecimalArithmetics) + { + this.pushdownDecimalArithmetics = pushdownDecimalArithmetics; + return this; + } + + public boolean isPushdownDoubleArithmetics() + { + return this.pushdownDoubleArithmetics; + } + + @Config("adb.pushdown.double-arithmetics") + @ConfigDescription("Whether to pushdown arithmetical operations on REAL and DOUBLE data types") + public AdbPushdownConfig setPushdownDoubleArithmetics(boolean pushdownDoubleArithmetics) + { + this.pushdownDoubleArithmetics = pushdownDoubleArithmetics; + return this; + } + + public boolean isPushdownDatetimeArithmetics() + { + return this.pushdownDatetimeArithmetics; + } + + @Config("adb.pushdown.datetime-arithmetics") + @ConfigDescription("Whether to pushdown arithmetical operations on date/time data types") + public AdbPushdownConfig setPushdownDatetimeArithmetics(boolean pushdownDatetimeArithmetics) + { + this.pushdownDatetimeArithmetics = pushdownDatetimeArithmetics; + return this; + } + + public boolean isPushdownDatetimeComparison() + { + return this.pushdownDatetimeComparison; + } + + @Config("adb.pushdown.datetime-comparison") + @ConfigDescription("Whether to pushdown comparison operations on date/time data types") + public AdbPushdownConfig setPushdownDatetimeComparison(boolean pushdownDatetimeComparison) + { + this.pushdownDatetimeComparison = pushdownDatetimeComparison; + return this; + } + + public boolean isPushdownFunctionCast() + { + return this.pushdownFunctionCast; + } + + @Config("adb.pushdown.function.cast") + @ConfigDescription("Whether to pushdown CAST function") + public AdbPushdownConfig setPushdownFunctionCast(boolean pushdownFunctionCast) + { + this.pushdownFunctionCast = pushdownFunctionCast; + return this; + } + + public boolean isPushdownFunctionDatePart() + { + return this.pushdownFunctionDatePart; + } + + @Config("adb.pushdown.function.date-part") + @ConfigDescription("Whether to pushdown DATE_PART functions") + public AdbPushdownConfig setPushdownFunctionDatePart(boolean pushdownFunctionDatePart) + { + this.pushdownFunctionDatePart = pushdownFunctionDatePart; + return this; + } + + public boolean isPushdownFunctionLike() + { + return this.pushdownFunctionLike; + } + + @Config("adb.pushdown.function.like") + @ConfigDescription("Whether to pushdown LIKE function") + public AdbPushdownConfig setPushdownFunctionLike(boolean pushdownFunctionLike) + { + this.pushdownFunctionLike = pushdownFunctionLike; + return this; + } + + public boolean isPushdownFunctionSubstring() + { + return this.pushdownFunctionSubstring; + } + + @Config("adb.pushdown.function.substring") + @ConfigDescription("Whether to pushdown SUBSTRING function") + public AdbPushdownConfig setPushdownFunctionSubstring(boolean pushdownFunctionSubstring) + { + this.pushdownFunctionSubstring = pushdownFunctionSubstring; + return this; + } + + public boolean isPushdownFunctionUpper() + { + return this.pushdownFunctionUpper; + } + + @Config("adb.pushdown.function.upper") + @ConfigDescription("Whether to pushdown UPPER function") + public AdbPushdownConfig setPushdownFunctionUpper(boolean pushdownFunctionUpper) + { + this.pushdownFunctionUpper = pushdownFunctionUpper; + return this; + } + + public boolean isPushdownFunctionLower() + { + return this.pushdownFunctionLower; + } + + @Config("adb.pushdown.function.lower") + @ConfigDescription("Whether to pushdown LOWER function") + public AdbPushdownConfig setPushdownFunctionLower(boolean pushdownFunctionLower) + { + this.pushdownFunctionLower = pushdownFunctionLower; + return this; + } +} diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbPushdownSessionProperties.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbPushdownSessionProperties.java new file mode 100644 index 0000000000000..16cb9fd65c0ee --- /dev/null +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbPushdownSessionProperties.java @@ -0,0 +1,156 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.adb.connector; + +import com.google.common.collect.ImmutableList; +import com.google.inject.Inject; +import io.trino.plugin.base.session.SessionPropertiesProvider; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.session.PropertyMetadata; + +import java.util.List; + +public class AdbPushdownSessionProperties + implements SessionPropertiesProvider +{ + public static final String PUSHDOWN_LITERALS = "pushdown_literals"; + public static final String PUSHDOWN_DECIMAL_ARITHMETICS = "pushdown_decimal_arithmetics"; + public static final String PUSHDOWN_DOUBLE_ARITHMETICS = "pushdown_double_arithmetics"; + public static final String PUSHDOWN_DATETIME_ARITHMETICS = "pushdown_datetime_arithmetics"; + public static final String PUSHDOWN_DATETIME_COMPARISON = "pushdown_datetime_comparison"; + public static final String PUSHDOWN_FUNCTION_CAST = "pushdown_function_cast"; + public static final String PUSHDOWN_FUNCTION_DATE_PART = "pushdown_function_date_part"; + public static final String PUSHDOWN_FUNCTION_LIKE = "pushdown_function_like"; + public static final String PUSHDOWN_FUNCTION_SUBSTRING = "pushdown_function_substring"; + public static final String PUSHDOWN_FUNCTION_UPPER = "pushdown_function_upper"; + public static final String PUSHDOWN_FUNCTION_LOWER = "pushdown_function_lower"; + private final List> sessionProperties; + + @Inject + public AdbPushdownSessionProperties(AdbPushdownConfig config) + { + ImmutableList.Builder> builder = ImmutableList.builder(); + this.sessionProperties = builder + .add(PropertyMetadata.booleanProperty( + PUSHDOWN_LITERALS, + "Whether to pushdown BOOLEAN, CHAR, REAL and DOUBLE literals, as well as literals with NULL values", + config.isPushdownLiterals(), + false)) + .add(PropertyMetadata.booleanProperty(PUSHDOWN_DECIMAL_ARITHMETICS, + "Whether to pushdown arithmetical operations on DECIMAL data type", + config.isPushdownDecimalArithmetics(), + false)) + .add(PropertyMetadata.booleanProperty( + PUSHDOWN_DOUBLE_ARITHMETICS, + "Whether to pushdown arithmetical operations on REAL and DOUBLE data types", + config.isPushdownDoubleArithmetics(), + false)) + .add(PropertyMetadata.booleanProperty( + PUSHDOWN_DATETIME_ARITHMETICS, + "Whether to pushdown arithmetical operations on date/time data types", + config.isPushdownDatetimeArithmetics(), + false)) + .add(PropertyMetadata.booleanProperty( + PUSHDOWN_DATETIME_COMPARISON, + "Whether to pushdown comparison operations on date/time data types", + config.isPushdownDatetimeComparison(), + false)) + .add(PropertyMetadata.booleanProperty(PUSHDOWN_FUNCTION_CAST, + "Whether to pushdown CAST function", + config.isPushdownFunctionCast(), + false)) + .add(PropertyMetadata.booleanProperty(PUSHDOWN_FUNCTION_DATE_PART, + "Whether to pushdown DATE_PART functions", + config.isPushdownFunctionDatePart(), + false)) + .add(PropertyMetadata.booleanProperty(PUSHDOWN_FUNCTION_LIKE, + "Whether to pushdown LIKE function", + config.isPushdownFunctionLike(), + false)) + .add(PropertyMetadata.booleanProperty(PUSHDOWN_FUNCTION_SUBSTRING, + "Whether to pushdown SUBSTRING function", + config.isPushdownFunctionSubstring(), + false)) + .add(PropertyMetadata.booleanProperty(PUSHDOWN_FUNCTION_UPPER, + "Whether to pushdown UPPER function", + config.isPushdownFunctionUpper(), + false)) + .add(PropertyMetadata.booleanProperty(PUSHDOWN_FUNCTION_LOWER, + "Whether to pushdown LOWER function", + config.isPushdownFunctionLower(), + false)) + .build(); + } + + @Override + public List> getSessionProperties() + { + return this.sessionProperties; + } + + public static boolean isPushdownLiterals(ConnectorSession session) + { + return session.getProperty(PUSHDOWN_LITERALS, Boolean.class); + } + + public static boolean isPushdownDecimalArithmetics(ConnectorSession session) + { + return session.getProperty(PUSHDOWN_DECIMAL_ARITHMETICS, Boolean.class); + } + + public static boolean isPushdownDoubleArithmetics(ConnectorSession session) + { + return session.getProperty(PUSHDOWN_DOUBLE_ARITHMETICS, Boolean.class); + } + + public static boolean isPushdownDatetimeArithmetics(ConnectorSession session) + { + return session.getProperty(PUSHDOWN_DATETIME_ARITHMETICS, Boolean.class); + } + + public static boolean isPushdownDatetimeComparison(ConnectorSession session) + { + return session.getProperty(PUSHDOWN_DATETIME_COMPARISON, Boolean.class); + } + + public static boolean isPushdownFunctionCast(ConnectorSession session) + { + return session.getProperty(PUSHDOWN_FUNCTION_CAST, Boolean.class); + } + + public static boolean isPushdownFunctionDatePart(ConnectorSession session) + { + return session.getProperty(PUSHDOWN_FUNCTION_DATE_PART, Boolean.class); + } + + public static boolean isPushdownFunctionLike(ConnectorSession session) + { + return session.getProperty(PUSHDOWN_FUNCTION_LIKE, Boolean.class); + } + + public static boolean isPushdownFunctionSubstring(ConnectorSession session) + { + return session.getProperty(PUSHDOWN_FUNCTION_SUBSTRING, Boolean.class); + } + + public static boolean isPushdownFunctionUpper(ConnectorSession session) + { + return session.getProperty(PUSHDOWN_FUNCTION_UPPER, Boolean.class); + } + + public static boolean isPushdownFunctionLower(ConnectorSession session) + { + return session.getProperty(PUSHDOWN_FUNCTION_LOWER, Boolean.class); + } +} diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbSessionProperties.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbSessionProperties.java index 2660b80833ba7..8e4b9a9809e19 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbSessionProperties.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbSessionProperties.java @@ -28,17 +28,29 @@ public class AdbSessionProperties implements SessionPropertiesProvider { + private static final String ARRAY_MAPPING_PROPERTY = "array_mapping"; + private static final String MAX_SCAN_PARALLELISM_PROPERTY = "max_scan_parallelism"; + private static final String GPFDIST_RETRY_TIMEOUT_PROPERTY = "gpfdist_retry_timeout"; + private static final String ENABLE_STRING_PUSHDOWN_WITH_COLLATE_PROPERTY = "enable_string_pushdown_with_collate"; private final List> sessionProperties; @Inject public AdbSessionProperties(AdbPluginConfig config) { this.sessionProperties = ImmutableList.of( - PropertyMetadata.enumProperty("array_mapping", "Handling of PostgreSql arrays", AdbPluginConfig.ArrayMapping.class, config.getArrayMapping(), false), + PropertyMetadata.enumProperty(ARRAY_MAPPING_PROPERTY, "Handling of PostgreSql arrays", + AdbPluginConfig.ArrayMapping.class, config.getArrayMapping(), false), PropertyMetadata.integerProperty( - "max_scan_parallelism", "Maximum degree of parallelism when scanning tables. Defaults to 1.", config.getMaxScanParallelism(), false), + MAX_SCAN_PARALLELISM_PROPERTY, + "Maximum degree of parallelism when scanning tables. Defaults to 1.", + config.getMaxScanParallelism(), false), + PropertyMetadata.booleanProperty( + ENABLE_STRING_PUSHDOWN_WITH_COLLATE_PROPERTY, + "Enable string pushdown with collate (experimental). Default false", + config.isEnableStringPushdownWithCollate(), false), PropertyMetadataUtil.durationProperty( - "gpfdist_retry_timeout", "Value of adb gpfdist_retry_timeout property", config.getGpfdistRetryTimeout(), false)); + GPFDIST_RETRY_TIMEOUT_PROPERTY, "Value of adb gpfdist_retry_timeout property", + config.getGpfdistRetryTimeout(), false)); } @Override @@ -49,21 +61,21 @@ public List> getSessionProperties() public static AdbPluginConfig.ArrayMapping getArrayMapping(ConnectorSession session) { - return session.getProperty("array_mapping", AdbPluginConfig.ArrayMapping.class); + return session.getProperty(ARRAY_MAPPING_PROPERTY, AdbPluginConfig.ArrayMapping.class); } public static boolean isEnableStringPushdownWithCollate(ConnectorSession session) { - return session.getProperty("enable_string_pushdown_with_collate", Boolean.class); + return session.getProperty(ENABLE_STRING_PUSHDOWN_WITH_COLLATE_PROPERTY, Boolean.class); } public static int getMaxScanParallelism(ConnectorSession session) { - return session.getProperty("max_scan_parallelism", Integer.class); + return session.getProperty(MAX_SCAN_PARALLELISM_PROPERTY, Integer.class); } public static Optional getGpfdistRetryTimeout(ConnectorSession session) { - return Optional.ofNullable(session.getProperty("gpfdist_retry_timeout", Duration.class)); + return Optional.ofNullable(session.getProperty(GPFDIST_RETRY_TIMEOUT_PROPERTY, Duration.class)); } } diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbSqlClient.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbSqlClient.java index 5102909979276..2703e6691efae 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbSqlClient.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbSqlClient.java @@ -19,8 +19,20 @@ import com.google.inject.Inject; import io.airlift.log.Logger; import io.trino.plugin.adb.AdbPluginConfig; +import io.trino.plugin.adb.connector.aggregation.AdbImplementAvgBigint; +import io.trino.plugin.adb.connector.aggregation.AdbImplementMinMax; +import io.trino.plugin.adb.connector.aggregation.AdbImplementStddevPop; +import io.trino.plugin.adb.connector.aggregation.AdbImplementStddevSamp; +import io.trino.plugin.adb.connector.aggregation.AdbImplementVariancePop; +import io.trino.plugin.adb.connector.aggregation.AdbImplementVarianceSamp; import io.trino.plugin.adb.connector.datatype.ColumnDataType; import io.trino.plugin.adb.connector.datatype.mapper.DataTypeMapper; +import io.trino.plugin.adb.connector.expression.AdbRewriteBooleanConstant; +import io.trino.plugin.adb.connector.expression.AdbRewriteCast; +import io.trino.plugin.adb.connector.expression.AdbRewriteCharConstant; +import io.trino.plugin.adb.connector.expression.AdbRewriteDatetimeArithmetics; +import io.trino.plugin.adb.connector.expression.AdbRewriteInexactNumericConstant; +import io.trino.plugin.adb.connector.expression.AdbRewriteNullConstant; import io.trino.plugin.adb.connector.metadata.AdbMetadataDao; import io.trino.plugin.adb.connector.table.AdbTableDistributed; import io.trino.plugin.adb.connector.table.AdbTableProperties; @@ -107,6 +119,7 @@ import java.util.OptionalLong; import java.util.StringJoiner; import java.util.function.BiFunction; +import java.util.function.Predicate; import java.util.stream.Stream; import static com.google.common.base.Preconditions.checkArgument; @@ -169,57 +182,145 @@ public AdbSqlClient(ConnectionFactory connectionFactory, this.metadata = metadata; this.fetchSize = Optional.empty(); this.pluginConfig = config; - connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() + connectorExpressionRewriter = createExpressionRewriter(dataTypeMapper); + JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, + Optional.of("bigint"), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty()); + this.aggregateFunctionRewriter = createAggregationFunctionRewriter(bigintTypeHandle); + } + + private ConnectorExpressionRewriter createExpressionRewriter(DataTypeMapper dataTypeMapper) + { + Predicate pushdownWithCollateEnabled = + AdbSessionProperties::isEnableStringPushdownWithCollate; + Predicate datetimeComparisonEnabled = + AdbPushdownSessionProperties::isPushdownDatetimeComparison; + Predicate decimalArithmeticsEnabled = + AdbPushdownSessionProperties::isPushdownDecimalArithmetics; + Predicate doubleArithmeticsEnabled = + AdbPushdownSessionProperties::isPushdownDoubleArithmetics; + Predicate functionDatePartEnabled = AdbPushdownSessionProperties::isPushdownFunctionDatePart; + Predicate functionLikeEnabled = AdbPushdownSessionProperties::isPushdownFunctionLike; + Predicate functionSubstringEnabled = + AdbPushdownSessionProperties::isPushdownFunctionSubstring; + Predicate functionUpperEnabled = AdbPushdownSessionProperties::isPushdownFunctionUpper; + Predicate functionLowerEnabled = AdbPushdownSessionProperties::isPushdownFunctionLower; + return JdbcConnectorExpressionRewriterBuilder.newBuilder() .add(new RewriteVariable(this::quoted)) .add(new RewriteVarcharConstant()) .add(new RewriteExactNumericConstant()) + .add(new AdbRewriteInexactNumericConstant()) + .add(new AdbRewriteBooleanConstant()) + .add(new AdbRewriteCharConstant()) + .add(new AdbRewriteNullConstant(dataTypeMapper)) .withTypeClass("integer_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint")) .withTypeClass("decimal_type", ImmutableSet.of("decimal")) .withTypeClass("double_type", ImmutableSet.of("real", "double")) - .withTypeClass("numeric_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint", "decimal", "real", "double")) + .withTypeClass("numeric_type", + ImmutableSet.of("tinyint", "smallint", "integer", "bigint", "decimal", "real", "double")) .withTypeClass("string_type", ImmutableSet.of("char", "varchar")) - .withTypeClass("datetime_type", ImmutableSet.of("date", "time", "timestamp", "timestamp with time zone")) + .withTypeClass("datetime_type", + ImmutableSet.of("date", "time", "timestamp", "timestamp with time zone")) .add(new RewriteAnd()) .add(new RewriteOr()) - .map("$not($is_null(value))") - .to("value IS NOT NULL") - .map("$not(value: boolean)") - .to("NOT value") + .map("$not($is_null(value))").to("value IS NOT NULL") + .map("$not(value: boolean)").to("NOT value") .add(new RewriteIn()) - .map("$is_null(value)") - .to("value IS NULL") - .map("$nullif(first, second)") - .to("NULLIF(first, second)") - .map("$equal(left, right)") - .to("left = right") - .map("$not_equal(left, right)") - .to("left <> right") - .map("$is_distinct_from(left, right)") - .to("left IS DISTINCT FROM right") - .map("$less_than(left: numeric_type, right: numeric_type)") + .map("$is_null(value)").to("value IS NULL") + .map("$nullif(first, second)").to("NULLIF(first, second)") + .map("$equal(left, right)").to("left = right") + .map("$not_equal(left, right)").to("left <> right") + .map("$is_distinct_from(left, right)").to("left IS DISTINCT FROM right") + .map("$less_than(left: numeric_type, right: numeric_type)").to("left < right") + .map("$less_than_or_equal(left: numeric_type, right: numeric_type)").to("left <= right") + .map("$greater_than(left: numeric_type, right: numeric_type)").to("left > right") + .map("$greater_than_or_equal(left: numeric_type, right: numeric_type)").to("left >= right") + .when(pushdownWithCollateEnabled).map("$less_than(left: string_type, right: string_type)") + .to("left < right COLLATE \"C\"") + .when(pushdownWithCollateEnabled).map("$less_than_or_equal(left: string_type, right: string_type)") + .to("left <= right COLLATE \"C\"") + .when(pushdownWithCollateEnabled).map("$greater_than(left: string_type, right: string_type)") + .to("left > right COLLATE \"C\"") + .when(pushdownWithCollateEnabled).map("$greater_than_or_equal(left: string_type, right: string_type)") + .to("left >= right COLLATE \"C\"") + .when(datetimeComparisonEnabled).map("$less_than(left: datetime_type, right: datetime_type)") .to("left < right") - .map("$less_than_or_equal(left: numeric_type, right: numeric_type)") + .when(datetimeComparisonEnabled).map("$less_than_or_equal(left: datetime_type, right: datetime_type)") .to("left <= right") - .map("$greater_than(left: numeric_type, right: numeric_type)") + .when(datetimeComparisonEnabled).map("$greater_than(left: datetime_type, right: datetime_type)") .to("left > right") - .map("$greater_than_or_equal(left: numeric_type, right: numeric_type)") - .to("left >= right") + .when(datetimeComparisonEnabled) + .map("$greater_than_or_equal(left: datetime_type, right: datetime_type)").to("left >= right") + .map("$add(left: integer_type, right: integer_type)").to("left + right") + .map("$subtract(left: integer_type, right: integer_type)").to("left - right") + .map("$multiply(left: integer_type, right: integer_type)").to("left * right") + .map("$divide(left: integer_type, right: integer_type)").to("left / right") + .map("$modulus(left: integer_type, right: integer_type)").to("left % right") + .map("$negate(value: integer_type)").to("-value") + .when(decimalArithmeticsEnabled).map("$add(left: decimal_type, right: decimal_type)").to("left + right") + .when(decimalArithmeticsEnabled).map("$subtract(left: decimal_type, right: decimal_type)") + .to("left - right") + .when(decimalArithmeticsEnabled).map("$multiply(left: decimal_type, right: decimal_type)") + .to("left * right") + .when(decimalArithmeticsEnabled).map("$divide(left: decimal_type, right: decimal_type)") + .to("left / right") + .when(decimalArithmeticsEnabled).map("$modulus(left: decimal_type, right: decimal_type)") + .to("left % right") + .when(decimalArithmeticsEnabled).map("$negate(value: decimal_type)").to("-value") + .when(doubleArithmeticsEnabled).map("$add(left: double_type, right: double_type)").to("left + right") + .when(doubleArithmeticsEnabled).map("$subtract(left: double_type, right: double_type)") + .to("left - right") + .when(doubleArithmeticsEnabled).map("$multiply(left: double_type, right: double_type)") + .to("left * right") + .when(doubleArithmeticsEnabled).map("$divide(left: double_type, right: double_type)").to("left / right") + .when(doubleArithmeticsEnabled).map("$modulus(left: double_type, right: double_type)") + .to("left % right") + .when(doubleArithmeticsEnabled).map("$negate(value: double_type)").to("-value") + .add(new AdbRewriteDatetimeArithmetics()) + .add(new AdbRewriteCast(dataTypeMapper)) + .when(functionDatePartEnabled).map("year(arg: timestamp)").to("DATE_PART('isoyear', arg)") + .when(functionDatePartEnabled).map("quarter(arg: timestamp)").to("DATE_PART('quarter', arg)") + .when(functionDatePartEnabled).map("month(arg: timestamp)").to("DATE_PART('month', arg)") + .when(functionDatePartEnabled).map("week(arg: timestamp)").to("DATE_PART('week', arg)") + .when(functionDatePartEnabled).map("day(arg: timestamp)").to("DATE_PART('day', arg)") + .when(functionDatePartEnabled).map("day_of_week(arg: timestamp)").to("DATE_PART('isodow', arg)") + .when(functionDatePartEnabled).map("day_of_year(arg: timestamp)").to("DATE_PART('doy', arg)") + .when(functionDatePartEnabled).map("hour(arg: timestamp)").to("DATE_PART('hour', arg)") + .when(functionDatePartEnabled).map("minute(arg: timestamp)").to("DATE_PART('minute', arg)") + .when(functionLikeEnabled).map("$like(value: string_type, pattern): boolean").to("value LIKE pattern") + .when(functionLikeEnabled).map("$like(value: string_type, pattern, escape): boolean") + .to("value LIKE pattern ESCAPE escape") + .when(functionSubstringEnabled).map("substring(arg1: string_type, arg2: integer_type)") + .to("SUBSTRING(arg1 FROM CAST(arg2 AS INT))") + .when(functionSubstringEnabled) + .map("substring(arg1: string_type, arg2: integer_type, arg3: integer_type)") + .to("SUBSTRING(arg1 FROM CAST(arg2 AS INT) FOR CAST(arg3 AS INT))") + .when(functionUpperEnabled).map("upper(arg)").to("UPPER(arg)") + .when(functionLowerEnabled).map("lower(arg)").to("LOWER(arg)") .build(); - JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, - Optional.of("bigint"), - Optional.empty(), - Optional.empty(), - Optional.empty(), - Optional.empty()); - this.aggregateFunctionRewriter = new AggregateFunctionRewriter<>( + } + + private AggregateFunctionRewriter createAggregationFunctionRewriter( + JdbcTypeHandle bigintTypeHandle) + { + return new AggregateFunctionRewriter<>( this.connectorExpressionRewriter, ImmutableSet.>builder() .add(new ImplementCountAll(bigintTypeHandle)) .add(new ImplementCount(bigintTypeHandle)) - .add(new ImplementCountDistinct(bigintTypeHandle, false)) + .add(new ImplementCountDistinct(bigintTypeHandle, true)) .add(new ImplementSum(AdbSqlClient::toDecimalTypeToTypeHandle)) .add(new ImplementAvgFloatingPoint()) .add(new ImplementAvgDecimal()) + .add(new AdbImplementMinMax()) + .add(new AdbImplementAvgBigint()) + .add(new AdbImplementStddevSamp()) + .add(new AdbImplementStddevPop()) + .add(new AdbImplementVarianceSamp()) + .add(new AdbImplementVariancePop()) .add(new ImplementCovarianceSamp()) .add(new ImplementCovariancePop()) .add(new ImplementCorr()) @@ -325,7 +426,8 @@ protected JdbcOutputTableHandle createTable( Optional pageSinkIdColumnName = Optional.empty(); if (pageSinkIdColumn.isPresent()) { - String columnName = identifierMapping.toRemoteColumnName(remoteIdentifiers, pageSinkIdColumn.get().getName()); + String columnName = + identifierMapping.toRemoteColumnName(remoteIdentifiers, pageSinkIdColumn.get().getName()); pageSinkIdColumnName = Optional.of(columnName); this.verifyColumnName(connection.getMetaData(), columnName); columnList.add(this.getColumnDefinitionSql(session, pageSinkIdColumn.get(), columnName)); @@ -371,23 +473,28 @@ private List createTableSqls( RemoteIdentifiers remoteIdentifiers) { ImmutableList.Builder createTableSqlsBuilder = ImmutableList.builder(); - String createTableSql = format("CREATE TABLE %s (%s)", this.quoted(remoteTableName), String.join(", ", columns)); + String createTableSql = + format("CREATE TABLE %s (%s)", this.quoted(remoteTableName), String.join(", ", columns)); Map tableProperties = tableMetadata.getProperties(); Map storageProperties = new HashMap<>(); Optional appendOptimized = AdbTableProperties.getAppendOptimized(tableProperties); - appendOptimized.ifPresent(aBoolean -> storageProperties.put("appendoptimized", aBoolean.toString().toUpperCase(Locale.ENGLISH))); + appendOptimized.ifPresent( + aBoolean -> storageProperties.put("appendoptimized", aBoolean.toString().toUpperCase(Locale.ENGLISH))); Optional blockSize = AdbTableProperties.getBlockSize(tableProperties); blockSize.ifPresent(integer -> storageProperties.put("blocksize", integer.toString())); Optional orientation = AdbTableProperties.getOrientation(tableProperties); - orientation.ifPresent(adbTableStorageOrientation -> storageProperties.put("orientation", adbTableStorageOrientation.name())); + orientation.ifPresent( + adbTableStorageOrientation -> storageProperties.put("orientation", adbTableStorageOrientation.name())); Optional checksum = AdbTableProperties.getChecksum(tableProperties); - checksum.ifPresent(aBoolean -> storageProperties.put("checksum", aBoolean.toString().toUpperCase(Locale.ENGLISH))); + checksum.ifPresent( + aBoolean -> storageProperties.put("checksum", aBoolean.toString().toUpperCase(Locale.ENGLISH))); Optional compressType = AdbTableProperties.getCompressType(tableProperties); - compressType.ifPresent(adbTableStorageCompressType -> storageProperties.put("compresstype", adbTableStorageCompressType.name())); + compressType.ifPresent(adbTableStorageCompressType -> storageProperties.put("compresstype", + adbTableStorageCompressType.name())); Optional compressLevel = AdbTableProperties.getCompressLevel(tableProperties); compressLevel.ifPresent(integer -> storageProperties.put("compresslevel", integer.toString())); @@ -425,11 +532,13 @@ private List createTableSqls( StringJoiner distributionColumnJoiner = new StringJoiner(", ", " DISTRIBUTED BY (", ")"); for (String distributionColumn : distributedBy.get()) { - String remoteDistributionColumn = this.getIdentifierMapping().toRemoteColumnName(remoteIdentifiers, distributionColumn); + String remoteDistributionColumn = + this.getIdentifierMapping().toRemoteColumnName(remoteIdentifiers, distributionColumn); if (!columnNames.contains(remoteDistributionColumn)) { throw new TrinoException( StandardErrorCode.INVALID_TABLE_PROPERTY, - format("%s property references non-existent column %s", "distributed_by", distributionColumn)); + format("%s property references non-existent column %s", "distributed_by", + distributionColumn)); } distributionColumnJoiner.add(remoteDistributionColumn); @@ -489,7 +598,8 @@ protected void renameTable( this.execute( session, connection, - format("ALTER TABLE %s RENAME TO %s", this.quoted(catalogName, remoteSchemaName, remoteTableName), this.quoted(newRemoteTableName))); + format("ALTER TABLE %s RENAME TO %s", this.quoted(catalogName, remoteSchemaName, remoteTableName), + this.quoted(newRemoteTableName))); } } @@ -538,14 +648,16 @@ public List getColumns(ConnectorSession session, JdbcTableHand allColumns++; String columnName = resultSet.getString("COLUMN_NAME"); JdbcTypeHandle typeHandle = new JdbcTypeHandle( - getInteger(resultSet, "DATA_TYPE").orElseThrow(() -> new IllegalStateException("DATA_TYPE is null")), + getInteger(resultSet, "DATA_TYPE").orElseThrow( + () -> new IllegalStateException("DATA_TYPE is null")), Optional.of(resultSet.getString("TYPE_NAME")), getInteger(resultSet, "COLUMN_SIZE"), getInteger(resultSet, "DECIMAL_DIGITS"), Optional.ofNullable(arrayColumnDimensions.get(columnName)), Optional.empty()); Optional columnMapping = toColumnMapping(session, connection, typeHandle); - log.debug("Mapping data type of '%s' column '%s': %s mapped to %s", schemaTableName, columnName, typeHandle, columnMapping); + log.debug("Mapping data type of '%s' column '%s': %s mapped to %s", schemaTableName, columnName, + typeHandle, columnMapping); // skip unsupported column types if (columnMapping.isPresent()) { boolean nullable = (resultSet.getInt("NULLABLE") != columnNoNulls); @@ -571,7 +683,8 @@ public List getColumns(ConnectorSession session, JdbcTableHand // A table may have no supported columns. In rare cases a table might have no columns at all. throw new TableNotFoundException( schemaTableName, - format("Table '%s' has no supported columns (all %s columns are not supported)", schemaTableName, allColumns)); + format("Table '%s' has no supported columns (all %s columns are not supported)", + schemaTableName, allColumns)); } return ImmutableList.copyOf(columns); } @@ -608,7 +721,8 @@ private static Map getArrayColumnDimensions(Connection connecti } @Override - public Optional toColumnMapping(ConnectorSession session, Connection connection, JdbcTypeHandle typeHandle) + public Optional toColumnMapping(ConnectorSession session, Connection connection, + JdbcTypeHandle typeHandle) { return dataTypeMapper.toColumnMapping(session, connection, typeHandle); } @@ -625,13 +739,15 @@ public WriteMapping toWriteMapping(ConnectorSession session, Type type) } @Override - public Optional implementAggregation(ConnectorSession session, AggregateFunction aggregate, Map assignments) + public Optional implementAggregation(ConnectorSession session, AggregateFunction aggregate, + Map assignments) { return aggregateFunctionRewriter.rewrite(session, aggregate, assignments); } @Override - public Optional convertPredicate(ConnectorSession session, ConnectorExpression expression, Map assignments) + public Optional convertPredicate(ConnectorSession session, ConnectorExpression expression, + Map assignments) { return connectorExpressionRewriter.rewrite(session, expression, assignments); } @@ -662,7 +778,8 @@ protected Optional topNFunction() if (isCollatable(sortItem.column())) { collation = "COLLATE \"C\""; } - return format("%s %s %s %s", quoted(sortItem.column().getColumnName()), collation, ordering, nullsHandling); + return format("%s %s %s %s", quoted(sortItem.column().getColumnName()), collation, ordering, + nullsHandling); }) .collect(joining(", ")); return format("%s ORDER BY %s LIMIT %d", query, orderBy, limit); @@ -673,7 +790,8 @@ public static boolean isCollatable(JdbcColumnHandle column) { if (column.getColumnType() instanceof CharType || column.getColumnType() instanceof VarcharType) { String jdbcTypeName = column.getJdbcTypeHandle().jdbcTypeName() - .orElseThrow(() -> new TrinoException(JDBC_ERROR, "Type name is missing: " + column.getJdbcTypeHandle())); + .orElseThrow(() -> new TrinoException(JDBC_ERROR, + "Type name is missing: " + column.getJdbcTypeHandle())); return isCollatable(jdbcTypeName); } // non-textual types don't have the concept of collation @@ -710,7 +828,8 @@ public OptionalLong delete(ConnectorSession session, JdbcTableHandle handle) checkArgument(handle.isNamedRelation(), "Unable to delete from synthetic table: %s", handle); checkArgument(handle.getLimit().isEmpty(), "Unable to delete when limit is set: %s", handle); checkArgument(handle.getSortOrder().isEmpty(), "Unable to delete when sort order is set: %s", handle); - checkArgument(handle.getUpdateAssignments().isEmpty(), "Unable to delete when update assignments are set: %s", handle); + checkArgument(handle.getUpdateAssignments().isEmpty(), "Unable to delete when update assignments are set: %s", + handle); try (Connection connection = connectionFactory.openConnection(session)) { verify(connection.getAutoCommit()); PreparedQuery preparedQuery = queryBuilder.prepareDeleteQuery( @@ -720,7 +839,8 @@ public OptionalLong delete(ConnectorSession session, JdbcTableHandle handle) handle.getRequiredNamedRelation(), handle.getConstraint(), getAdditionalPredicate(handle.getConstraintExpressions(), Optional.empty())); - try (PreparedStatement preparedStatement = queryBuilder.prepareStatement(this, session, connection, preparedQuery, Optional.empty())) { + try (PreparedStatement preparedStatement = queryBuilder.prepareStatement(this, session, connection, + preparedQuery, Optional.empty())) { int affectedRowsCount = preparedStatement.executeUpdate(); // In getPreparedStatement we set autocommit to false so here we need an explicit commit connection.commit(); @@ -738,7 +858,8 @@ public OptionalLong update(ConnectorSession session, JdbcTableHandle handle) checkArgument(handle.isNamedRelation(), "Unable to update from synthetic table: %s", handle); checkArgument(handle.getLimit().isEmpty(), "Unable to update when limit is set: %s", handle); checkArgument(handle.getSortOrder().isEmpty(), "Unable to update when sort order is set: %s", handle); - checkArgument(!handle.getUpdateAssignments().isEmpty(), "Unable to update when update assignments are not set: %s", handle); + checkArgument(!handle.getUpdateAssignments().isEmpty(), + "Unable to update when update assignments are not set: %s", handle); try (Connection connection = connectionFactory.openConnection(session)) { verify(connection.getAutoCommit()); PreparedQuery preparedQuery = queryBuilder.prepareUpdateQuery( @@ -749,7 +870,8 @@ public OptionalLong update(ConnectorSession session, JdbcTableHandle handle) handle.getConstraint(), getAdditionalPredicate(handle.getConstraintExpressions(), Optional.empty()), handle.getUpdateAssignments()); - try (PreparedStatement preparedStatement = queryBuilder.prepareStatement(this, session, connection, preparedQuery, Optional.empty())) { + try (PreparedStatement preparedStatement = queryBuilder.prepareStatement(this, session, connection, + preparedQuery, Optional.empty())) { int affectedRows = preparedStatement.executeUpdate(); connection.commit(); return OptionalLong.of(affectedRows); @@ -793,7 +915,8 @@ public Optional implementJoin( leftSource, rightSource, statistics, - () -> super.implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, joinConditions, statistics)); + () -> super.implementJoin(session, joinType, leftSource, leftProjections, rightSource, rightProjections, + joinConditions, statistics)); } @Override @@ -817,7 +940,8 @@ public Optional legacyImplementJoin( leftSource, rightSource, statistics, - () -> super.legacyImplementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics)); + () -> super.legacyImplementJoin(session, joinType, leftSource, rightSource, joinConditions, + rightAssignments, leftAssignments, statistics)); } @Override @@ -843,7 +967,9 @@ protected void verifySchemaName(DatabaseMetaData databaseMetadata, String schema throws SQLException { if (schemaName.length() > databaseMetadata.getMaxSchemaNameLength()) { - throw new TrinoException(NOT_SUPPORTED, format("Schema name must be shorter than or equal to '%s' characters but got '%s'", databaseMetadata.getMaxSchemaNameLength(), schemaName.length())); + throw new TrinoException(NOT_SUPPORTED, + format("Schema name must be shorter than or equal to '%s' characters but got '%s'", + databaseMetadata.getMaxSchemaNameLength(), schemaName.length())); } } @@ -852,7 +978,9 @@ protected void verifyTableName(DatabaseMetaData databaseMetadata, String tableNa throws SQLException { if (tableName.length() > databaseMetadata.getMaxTableNameLength()) { - throw new TrinoException(NOT_SUPPORTED, format("Table name must be shorter than or equal to '%s' characters but got '%s'", databaseMetadata.getMaxTableNameLength(), tableName.length())); + throw new TrinoException(NOT_SUPPORTED, + format("Table name must be shorter than or equal to '%s' characters but got '%s'", + databaseMetadata.getMaxTableNameLength(), tableName.length())); } } @@ -863,12 +991,15 @@ protected void verifyColumnName(DatabaseMetaData databaseMetadata, String column // PostgreSQL truncates table name to 63 chars silently // PostgreSQL driver caches the max column name length in a DatabaseMetaData object. The cost to call this method per column is low. if (columnName.length() > databaseMetadata.getMaxColumnNameLength()) { - throw new TrinoException(NOT_SUPPORTED, format("Column name must be shorter than or equal to '%s' characters but got '%s': '%s'", databaseMetadata.getMaxColumnNameLength(), columnName.length(), columnName)); + throw new TrinoException(NOT_SUPPORTED, + format("Column name must be shorter than or equal to '%s' characters but got '%s': '%s'", + databaseMetadata.getMaxColumnNameLength(), columnName.length(), columnName)); } } @Override - public void setColumnComment(ConnectorSession session, JdbcTableHandle handle, JdbcColumnHandle column, Optional comment) + public void setColumnComment(ConnectorSession session, JdbcTableHandle handle, JdbcColumnHandle column, + Optional comment) { // adb doesn't support prepared statement for COMMENT statement String sql = format( diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbBaseImplementStddevVariance.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbBaseImplementStddevVariance.java new file mode 100644 index 0000000000000..3270c3d95d56f --- /dev/null +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbBaseImplementStddevVariance.java @@ -0,0 +1,80 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.adb.connector.aggregation; + +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.base.aggregation.AggregateFunctionPatterns; +import io.trino.plugin.base.aggregation.AggregateFunctionRule; +import io.trino.plugin.base.expression.ConnectorExpressionPatterns; +import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; +import io.trino.spi.connector.AggregateFunction; +import io.trino.spi.expression.Variable; +import io.trino.spi.type.DoubleType; +import io.trino.spi.type.RealType; + +import java.util.Optional; + +import static com.google.common.base.Verify.verify; +import static java.lang.String.format; +import static java.sql.Types.DOUBLE; + +public abstract class AdbBaseImplementStddevVariance + implements AggregateFunctionRule +{ + private static final JdbcTypeHandle DOUBLE_TYPE_HANDLE = new JdbcTypeHandle( + DOUBLE, + Optional.of("double"), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty()); + private static final Capture ARGUMENT = Capture.newCapture(); + private final String trinoFunctionName; + private final String greenplumFunctionName; + + protected AdbBaseImplementStddevVariance(String trinoFunctionName, String greenplumFunctionName) + { + this.trinoFunctionName = trinoFunctionName; + this.greenplumFunctionName = greenplumFunctionName; + } + + @Override + public Pattern getPattern() + { + return AggregateFunctionPatterns.basicAggregation() + .with(AggregateFunctionPatterns.functionName().equalTo(trinoFunctionName)) + .with(AggregateFunctionPatterns.singleArgument() + .matching(ConnectorExpressionPatterns.variable().capturedAs(ARGUMENT))); + } + + @Override + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, + RewriteContext context) + { + Variable argument = captures.get(ARGUMENT); + verify(aggregateFunction.getOutputType() == DoubleType.DOUBLE); + ParameterizedExpression rewrittenArgument = context.rewriteExpression(argument).orElseThrow(); + String arg = rewrittenArgument.expression(); + if (argument.getType() != RealType.REAL && argument.getType() != DoubleType.DOUBLE) { + arg = format("CAST(%s as double precision)", arg); + } + return Optional.of( + new JdbcExpression(format("%s(%s)", greenplumFunctionName, arg), rewrittenArgument.parameters(), + DOUBLE_TYPE_HANDLE)); + } +} diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementAvgBigint.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementAvgBigint.java new file mode 100644 index 0000000000000..4b59e9c92efe4 --- /dev/null +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementAvgBigint.java @@ -0,0 +1,26 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.adb.connector.aggregation; + +import io.trino.plugin.jdbc.aggregation.BaseImplementAvgBigint; + +public class AdbImplementAvgBigint + extends BaseImplementAvgBigint +{ + @Override + protected String getRewriteFormatExpression() + { + return "avg(CAST(%s AS double precision))"; + } +} diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementMinMax.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementMinMax.java new file mode 100644 index 0000000000000..394de7d3877aa --- /dev/null +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementMinMax.java @@ -0,0 +1,70 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.adb.connector.aggregation; + +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.base.aggregation.AggregateFunctionPatterns; +import io.trino.plugin.base.aggregation.AggregateFunctionRule; +import io.trino.plugin.base.expression.ConnectorExpressionPatterns; +import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; +import io.trino.spi.connector.AggregateFunction; +import io.trino.spi.expression.Variable; +import io.trino.spi.type.CharType; +import io.trino.spi.type.VarcharType; + +import java.util.Optional; +import java.util.Set; + +import static com.google.common.base.Verify.verify; +import static java.lang.String.format; + +public class AdbImplementMinMax + implements AggregateFunctionRule +{ + private static final Capture ARGUMENT = Capture.newCapture(); + + @Override + public Pattern getPattern() + { + return AggregateFunctionPatterns.basicAggregation() + .with(AggregateFunctionPatterns.functionName().matching(Set.of("min", "max")::contains)) + .with(AggregateFunctionPatterns.singleArgument() + .matching(ConnectorExpressionPatterns.variable().capturedAs(ARGUMENT))); + } + + @Override + public Optional rewrite(AggregateFunction aggregateFunction, + Captures captures, + RewriteContext context) + { + Variable argument = captures.get(ARGUMENT); + JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName()); + verify(columnHandle.getColumnType().equals(aggregateFunction.getOutputType())); + Optional suffix = Optional.empty(); + if (columnHandle.getColumnType() instanceof CharType || columnHandle.getColumnType() instanceof VarcharType) { + suffix = Optional.of(" COLLATE \"C\""); + } + ParameterizedExpression rewrittenArgument = context.rewriteExpression(argument).orElseThrow(); + return Optional.of( + new JdbcExpression( + format("%s(%s%s)", aggregateFunction.getFunctionName(), rewrittenArgument.expression(), + suffix.orElse("")), + rewrittenArgument.parameters(), + columnHandle.getJdbcTypeHandle())); + } +} diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementStddevPop.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementStddevPop.java new file mode 100644 index 0000000000000..8364950aa0945 --- /dev/null +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementStddevPop.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.adb.connector.aggregation; + +public class AdbImplementStddevPop + extends AdbBaseImplementStddevVariance +{ + public AdbImplementStddevPop() + { + super("stddev_pop", "stddev_pop"); + } +} diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementStddevSamp.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementStddevSamp.java new file mode 100644 index 0000000000000..827472661940d --- /dev/null +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementStddevSamp.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.adb.connector.aggregation; + +public class AdbImplementStddevSamp + extends AdbBaseImplementStddevVariance +{ + public AdbImplementStddevSamp() + { + super("stddev", "stddev_samp"); + } +} diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementVariancePop.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementVariancePop.java new file mode 100644 index 0000000000000..aa8da22af95e4 --- /dev/null +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementVariancePop.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.adb.connector.aggregation; + +public class AdbImplementVariancePop + extends AdbBaseImplementStddevVariance +{ + public AdbImplementVariancePop() + { + super("var_pop", "var_pop"); + } +} diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementVarianceSamp.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementVarianceSamp.java new file mode 100644 index 0000000000000..ce12b7045fd6f --- /dev/null +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementVarianceSamp.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.adb.connector.aggregation; + +public class AdbImplementVarianceSamp + extends AdbBaseImplementStddevVariance +{ + public AdbImplementVarianceSamp() + { + super("variance", "var_samp"); + } +} diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/datatype/VarcharDataType.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/datatype/VarcharDataType.java index b0a14640da154..8b814767df95d 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/datatype/VarcharDataType.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/datatype/VarcharDataType.java @@ -23,7 +23,8 @@ public class VarcharDataType public VarcharDataType(VarcharType varcharType) { - this.name = varcharType.isUnbounded() ? "varchar" : String.format("varchar(%d)", varcharType.getBoundedLength()); + this.name = + varcharType.isUnbounded() ? "varchar" : String.format("varchar(%d)", varcharType.getBoundedLength()); this.varcharType = varcharType; } diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/datatype/mapper/DataTypeMapperImpl.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/datatype/mapper/DataTypeMapperImpl.java index f25c3a6097325..1d68707bd4e17 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/datatype/mapper/DataTypeMapperImpl.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/datatype/mapper/DataTypeMapperImpl.java @@ -18,7 +18,6 @@ import com.google.inject.Inject; import io.airlift.slice.Slice; import io.trino.plugin.adb.AdbPluginConfig; -import io.trino.plugin.adb.connector.AdbColumnMapping; import io.trino.plugin.adb.connector.AdbSessionProperties; import io.trino.plugin.adb.connector.datatype.BigintDataType; import io.trino.plugin.adb.connector.datatype.BitDataType; @@ -43,6 +42,7 @@ import io.trino.plugin.adb.connector.datatype.UnknownDataType; import io.trino.plugin.adb.connector.datatype.UuidDataType; import io.trino.plugin.adb.connector.datatype.VarcharDataType; +import io.trino.plugin.adb.connector.table.AdbColumnMapping; import io.trino.plugin.jdbc.BaseJdbcConfig; import io.trino.plugin.jdbc.BooleanReadFunction; import io.trino.plugin.jdbc.ColumnMapping; @@ -202,7 +202,8 @@ else if (AdbSessionProperties.isEnableStringPushdownWithCollate(session)) { return PredicatePushdownController.FULL_PUSHDOWN.apply(session, domain); } else { - Domain simplifiedDomain = domain.simplify(JdbcMetadataSessionProperties.getDomainCompactionThreshold(session)); + Domain simplifiedDomain = + domain.simplify(JdbcMetadataSessionProperties.getDomainCompactionThreshold(session)); return !simplifiedDomain.getValues().isDiscreteSet() ? DISABLE_PUSHDOWN.apply(session, domain) : PredicatePushdownController.FULL_PUSHDOWN.apply(session, simplifiedDomain); @@ -220,7 +221,8 @@ public DataTypeMapperImpl(TypeManager typeManager, BaseJdbcConfig jdbcConfig) } @Override - public Optional toColumnMapping(ConnectorSession session, Connection connection, JdbcTypeHandle typeHandle) + public Optional toColumnMapping(ConnectorSession session, Connection connection, + JdbcTypeHandle typeHandle) { return Optional.ofNullable(toColumnMappingInternal(session, Optional.of(connection), typeHandle)) .map(AdbColumnMapping::columnMapping); @@ -278,14 +280,17 @@ public WriteMapping toWriteMapping(ConnectorSession session, Type type) } if (type instanceof TimeType timeType) { verify(timeType.getPrecision() <= POSTGRESQL_MAX_SUPPORTED_TIMESTAMP_PRECISION); - return WriteMapping.longMapping(format("time(%s)", timeType.getPrecision()), timeWriteFunction(timeType.getPrecision())); + return WriteMapping.longMapping(format("time(%s)", timeType.getPrecision()), + timeWriteFunction(timeType.getPrecision())); } if (type instanceof TimestampType timestampType) { if (timestampType.getPrecision() <= POSTGRESQL_MAX_SUPPORTED_TIMESTAMP_PRECISION) { - return WriteMapping.longMapping(format("timestamp(%s)", timestampType.getPrecision()), DataTypeMapperImpl::shortTimestampWriteFunction); + return WriteMapping.longMapping(format("timestamp(%s)", timestampType.getPrecision()), + DataTypeMapperImpl::shortTimestampWriteFunction); } else { - return WriteMapping.objectMapping(format("timestamp(%s)", POSTGRESQL_MAX_SUPPORTED_TIMESTAMP_PRECISION), longTimestampWriteFunction()); + return WriteMapping.objectMapping(format("timestamp(%s)", POSTGRESQL_MAX_SUPPORTED_TIMESTAMP_PRECISION), + longTimestampWriteFunction()); } } if (type instanceof TimestampWithTimeZoneType timestampWithTimeZoneType) { @@ -305,7 +310,8 @@ public WriteMapping toWriteMapping(ConnectorSession session, Type type) if (type instanceof ArrayType arrayType && getArrayMapping(session) == AS_ARRAY) { Type elementType = arrayType.getElementType(); String elementDataType = toWriteMapping(session, elementType).getDataType(); - return WriteMapping.objectMapping(elementDataType + "[]", arrayWriteFunction(session, elementType, getArrayElementPgTypeName(session, this, elementType))); + return WriteMapping.objectMapping(elementDataType + "[]", + arrayWriteFunction(session, elementType, getArrayElementPgTypeName(session, this, elementType))); } throw new TrinoException(NOT_SUPPORTED, "Unsupported column type: " + type.getDisplayName()); } @@ -315,10 +321,12 @@ public ColumnDataType getColumnDataType(ConnectorSession session, JdbcTypeHandle { return Optional.ofNullable(toColumnMappingInternal(session, Optional.empty(), typeHandle)) .map(AdbColumnMapping::columnDataType) - .orElseThrow(() -> new IllegalArgumentException("Failed to get column data type for type: " + typeHandle)); + .orElseThrow( + () -> new IllegalArgumentException("Failed to get column data type for type: " + typeHandle)); } - private AdbColumnMapping toColumnMappingInternal(ConnectorSession session, Optional connection, JdbcTypeHandle typeHandle) + private AdbColumnMapping toColumnMappingInternal(ConnectorSession session, Optional connection, + JdbcTypeHandle typeHandle) { String jdbcTypeName = typeHandle.jdbcTypeName() .orElseThrow(() -> new TrinoException(JDBC_ERROR, "Type name is missing: " + typeHandle)); @@ -360,14 +368,17 @@ private AdbColumnMapping toColumnMappingInternal(ConnectorSession session, Optio if (getDecimalRounding(session) == ALLOW_OVERFLOW) { if (columnSize == PRECISION_OF_UNSPECIFIED_DECIMAL) { // decimal type with unspecified scale - up to 131072 digits before the decimal point; up to 16383 digits after the decimal point) - DecimalType decimalType = createDecimalType(Decimals.MAX_PRECISION, getDecimalDefaultScale(session)); - return new AdbColumnMapping(decimalColumnMapping(decimalType, getDecimalRoundingMode(session)), new UnknownDataType()); + DecimalType decimalType = + createDecimalType(Decimals.MAX_PRECISION, getDecimalDefaultScale(session)); + return new AdbColumnMapping(decimalColumnMapping(decimalType, getDecimalRoundingMode(session)), + new UnknownDataType()); } precision = columnSize; if (precision > Decimals.MAX_PRECISION) { int scale = min(decimalDigits, getDecimalDefaultScale(session)); DecimalType decimalType = createDecimalType(Decimals.MAX_PRECISION, scale); - return new AdbColumnMapping(decimalColumnMapping(decimalType, getDecimalRoundingMode(session)), new UnknownDataType()); + return new AdbColumnMapping(decimalColumnMapping(decimalType, getDecimalRoundingMode(session)), + new UnknownDataType()); } } precision = columnSize + max(-decimalDigits, 0); @@ -383,13 +394,15 @@ private AdbColumnMapping toColumnMappingInternal(ConnectorSession session, Optio } case Types.CHAR: ColumnMapping charColumnMapping = charColumnMapping(typeHandle.requiredColumnSize()); - return new AdbColumnMapping(charColumnMapping, new CharDataType((CharType) charColumnMapping.getType())); + return new AdbColumnMapping(charColumnMapping, + new CharDataType((CharType) charColumnMapping.getType())); case Types.VARCHAR: if (!jdbcTypeName.equals("varchar")) { return new AdbColumnMapping(enumColumnMapping(session, jdbcTypeName), new UnknownDataType()); } ColumnMapping varcharColumnMapping = varcharColumnMapping(typeHandle.requiredColumnSize()); - return new AdbColumnMapping(varcharColumnMapping, new VarcharDataType((VarcharType) varcharColumnMapping.getType())); + return new AdbColumnMapping(varcharColumnMapping, + new VarcharDataType((VarcharType) varcharColumnMapping.getType())); case Types.BINARY: ColumnMapping columnMapping = varbinaryColumnMapping(); if (jdbcTypeName.equals("bytea")) { @@ -400,16 +413,19 @@ private AdbColumnMapping toColumnMappingInternal(ConnectorSession session, Optio return new AdbColumnMapping(dateColumnMappingUsingLocalDate(), new DateDataType()); case Types.TIME: int requiredDecimalDigits = typeHandle.requiredDecimalDigits(); - return new AdbColumnMapping(timeColumnMapping(requiredDecimalDigits), new TimeDataType(requiredDecimalDigits)); + return new AdbColumnMapping(timeColumnMapping(requiredDecimalDigits), + new TimeDataType(requiredDecimalDigits)); case Types.TIMESTAMP: TimestampType timestampType = createTimestampType(typeHandle.requiredDecimalDigits()); return new AdbColumnMapping( ColumnMapping.longMapping( - timestampType, timestampReadFunction(timestampType), DataTypeMapperImpl::shortTimestampWriteFunction), + timestampType, timestampReadFunction(timestampType), + DataTypeMapperImpl::shortTimestampWriteFunction), new TimestampWithoutTimeZoneDataType(timestampType)); case Types.ARRAY: if (connection.isPresent()) { - Optional arrayColumnMapping = arrayToTrinoType(session, connection.get(), typeHandle); + Optional arrayColumnMapping = + arrayToTrinoType(session, connection.get(), typeHandle); if (arrayColumnMapping.isPresent()) { return new AdbColumnMapping(arrayColumnMapping.get(), new UnknownDataType()); } @@ -514,7 +530,8 @@ public List getColumnDataTypes(ConnectorSession session, JdbcOut private Optional getForcedMappingToVarchar(JdbcTypeHandle typeHandle) { - if (typeHandle.jdbcTypeName().isPresent() && jdbcTypesMappedToVarchar.contains(typeHandle.jdbcTypeName().get())) { + if (typeHandle.jdbcTypeName().isPresent() && + jdbcTypesMappedToVarchar.contains(typeHandle.jdbcTypeName().get())) { return mapToUnboundedVarchar(typeHandle); } return Optional.empty(); @@ -529,7 +546,8 @@ protected static Optional mapToUnboundedVarchar(JdbcTypeHandle ty (statement, index, value) -> { throw new TrinoException( NOT_SUPPORTED, - "Underlying type that is mapped to VARCHAR is not supported for INSERT: " + typeHandle.jdbcTypeName().get()); + "Underlying type that is mapped to VARCHAR is not supported for INSERT: " + + typeHandle.jdbcTypeName().get()); }, DISABLE_PUSHDOWN)); } @@ -563,7 +581,8 @@ public Slice readSlice(ResultSet resultSet, int columnIndex) private static LongWriteFunction timeWriteFunction(int precision) { - checkArgument(precision <= 6, "Unsupported precision: %s", precision); // PostgreSQL limit but also assumption within this method + checkArgument(precision <= 6, "Unsupported precision: %s", + precision); // PostgreSQL limit but also assumption within this method String bindExpression = format("CAST(? AS time(%s))", precision); return new LongWriteFunction() { @@ -591,7 +610,9 @@ public void set(PreparedStatement statement, int index, long picosOfDay) private static SliceWriteFunction typedVarcharWriteFunction(String jdbcTypeName) { requireNonNull(jdbcTypeName, "jdbcTypeName is null"); - String quotedJdbcTypeName = jdbcTypeName.startsWith("\"") && jdbcTypeName.endsWith("\"") ? jdbcTypeName : "\"%s\"".formatted(jdbcTypeName.replace("\"", "\"\"")); + String quotedJdbcTypeName = + jdbcTypeName.startsWith("\"") && jdbcTypeName.endsWith("\"") ? jdbcTypeName : "\"%s\"".formatted( + jdbcTypeName.replace("\"", "\"\"")); String bindExpression = format("CAST(? AS %s)", quotedJdbcTypeName); return new SliceWriteFunction() @@ -631,7 +652,8 @@ private ColumnMapping jsonColumnMapping() private static AdbColumnMapping timestampWithTimeZoneColumnMapping(int precision) { // Adb supports timestamptz precision up to microseconds - checkArgument(precision <= POSTGRESQL_MAX_SUPPORTED_TIMESTAMP_PRECISION, "unsupported precision value %s", precision); + checkArgument(precision <= POSTGRESQL_MAX_SUPPORTED_TIMESTAMP_PRECISION, "unsupported precision value %s", + precision); TimestampWithTimeZoneType trinoType = createTimestampWithTimeZoneType(precision); if (precision <= TimestampWithTimeZoneType.MAX_SHORT_PRECISION) { return new AdbColumnMapping(ColumnMapping.longMapping( @@ -686,8 +708,11 @@ private static ObjectWriteFunction longTimestampWithTimeZoneWriteFunction() (statement, index, value) -> { // Adb does not store zone information in "timestamp with time zone" data type long epochSeconds = floorDiv(value.getEpochMillis(), MILLISECONDS_PER_SECOND); - long nanosOfSecond = (long) floorMod(value.getEpochMillis(), MILLISECONDS_PER_SECOND) * NANOSECONDS_PER_MILLISECOND + value.getPicosOfMilli() / PICOSECONDS_PER_NANOSECOND; - statement.setObject(index, OffsetDateTime.ofInstant(Instant.ofEpochSecond(epochSeconds, nanosOfSecond), UTC_KEY.getZoneId())); + long nanosOfSecond = (long) floorMod(value.getEpochMillis(), MILLISECONDS_PER_SECOND) * + NANOSECONDS_PER_MILLISECOND + value.getPicosOfMilli() / PICOSECONDS_PER_NANOSECOND; + statement.setObject(index, + OffsetDateTime.ofInstant(Instant.ofEpochSecond(epochSeconds, nanosOfSecond), + UTC_KEY.getZoneId())); }); } @@ -710,7 +735,8 @@ private static ColumnMapping charColumnMapping(int charLength) private static ColumnMapping varcharColumnMapping(int varcharLength) { - VarcharType varcharType = varcharLength <= 2147483646 ? VarcharType.createVarcharType(varcharLength) : VarcharType.createUnboundedVarcharType(); + VarcharType varcharType = varcharLength <= 2147483646 ? VarcharType.createVarcharType( + varcharLength) : VarcharType.createUnboundedVarcharType(); return ColumnMapping.sliceMapping( varcharType, varcharReadFunction(varcharType), @@ -722,7 +748,8 @@ private static ColumnMapping enumColumnMapping(ConnectorSession session, String { //todo implement AdbAdvancedPushdownSessionProperties.isPushdownEnums(session); boolean pushdownEnums = false; - PredicatePushdownController pushdownController = pushdownEnums ? ADB_STRING_COLLATION_AWARE_PUSHDOWN : DISABLE_PUSHDOWN; + PredicatePushdownController pushdownController = + pushdownEnums ? ADB_STRING_COLLATION_AWARE_PUSHDOWN : DISABLE_PUSHDOWN; return ColumnMapping.sliceMapping( VARCHAR, (resultSet, columnIndex) -> utf8Slice(resultSet.getString(columnIndex)), @@ -773,7 +800,8 @@ private static ObjectWriteFunction longTimestampWriteFunction() }); } - private Optional arrayToTrinoType(ConnectorSession session, Connection connection, JdbcTypeHandle typeHandle) + private Optional arrayToTrinoType(ConnectorSession session, Connection connection, + JdbcTypeHandle typeHandle) { checkArgument(typeHandle.jdbcType() == Types.ARRAY, "Not array type"); AdbPluginConfig.ArrayMapping arrayMapping = getArrayMapping(session); @@ -782,7 +810,8 @@ private Optional arrayToTrinoType(ConnectorSession session, Conne } JdbcTypeHandle baseElementTypeHandle = getArrayElementTypeHandle(connection, typeHandle); String baseElementTypeName = baseElementTypeHandle.jdbcTypeName() - .orElseThrow(() -> new TrinoException(JDBC_ERROR, "Element type name is missing: " + baseElementTypeHandle)); + .orElseThrow( + () -> new TrinoException(JDBC_ERROR, "Element type name is missing: " + baseElementTypeHandle)); if (baseElementTypeHandle.jdbcType() == Types.BINARY) { // adb jdbc driver doesn't currently support array of varbinary (bytea[]) return Optional.empty(); @@ -796,12 +825,14 @@ private Optional arrayToTrinoType(ConnectorSession session, Conne return baseElementMapping .map(elementMapping -> { ArrayType trinoArrayType = new ArrayType(elementMapping.getType()); - ColumnMapping arrayColumnMapping = arrayColumnMapping(session, trinoArrayType, elementMapping, baseElementTypeName); + ColumnMapping arrayColumnMapping = + arrayColumnMapping(session, trinoArrayType, elementMapping, baseElementTypeName); int arrayDimensions = typeHandle.arrayDimensions().get(); for (int i = 1; i < arrayDimensions; i++) { trinoArrayType = new ArrayType(trinoArrayType); - arrayColumnMapping = arrayColumnMapping(session, trinoArrayType, arrayColumnMapping, baseElementTypeName); + arrayColumnMapping = arrayColumnMapping(session, trinoArrayType, arrayColumnMapping, + baseElementTypeName); } return arrayColumnMapping; }); @@ -813,7 +844,8 @@ private Optional arrayToTrinoType(ConnectorSession session, Conne throw new IllegalStateException("Unsupported array mapping type: " + arrayMapping); } - private static ColumnMapping arrayColumnMapping(ConnectorSession session, ArrayType arrayType, ColumnMapping arrayElementMapping, String baseElementJdbcTypeName) + private static ColumnMapping arrayColumnMapping(ConnectorSession session, ArrayType arrayType, + ColumnMapping arrayElementMapping, String baseElementJdbcTypeName) { return ColumnMapping.objectMapping( arrayType, @@ -821,10 +853,12 @@ private static ColumnMapping arrayColumnMapping(ConnectorSession session, ArrayT arrayWriteFunction(session, arrayType.getElementType(), baseElementJdbcTypeName)); } - private static ObjectWriteFunction arrayWriteFunction(ConnectorSession session, Type elementType, String baseElementJdbcTypeName) + private static ObjectWriteFunction arrayWriteFunction(ConnectorSession session, Type elementType, + String baseElementJdbcTypeName) { return ObjectWriteFunction.of(Block.class, (statement, index, block) -> { - Array jdbcArray = statement.getConnection().createArrayOf(baseElementJdbcTypeName, getJdbcObjectArray(session, elementType, block)); + Array jdbcArray = statement.getConnection() + .createArrayOf(baseElementJdbcTypeName, getJdbcObjectArray(session, elementType, block)); statement.setArray(index, jdbcArray); }); } @@ -898,19 +932,29 @@ private static ObjectReadFunction arrayReadFunction(Type elementType, ReadFuncti builder.appendNull(); } else if (elementType.getJavaType() == boolean.class) { - elementType.writeBoolean(builder, ((BooleanReadFunction) elementReadFunction).readBoolean(arrayAsResultSet, ARRAY_RESULT_SET_VALUE_COLUMN)); + elementType.writeBoolean(builder, + ((BooleanReadFunction) elementReadFunction).readBoolean(arrayAsResultSet, + ARRAY_RESULT_SET_VALUE_COLUMN)); } else if (elementType.getJavaType() == long.class) { - elementType.writeLong(builder, ((LongReadFunction) elementReadFunction).readLong(arrayAsResultSet, ARRAY_RESULT_SET_VALUE_COLUMN)); + elementType.writeLong(builder, + ((LongReadFunction) elementReadFunction).readLong(arrayAsResultSet, + ARRAY_RESULT_SET_VALUE_COLUMN)); } else if (elementType.getJavaType() == double.class) { - elementType.writeDouble(builder, ((DoubleReadFunction) elementReadFunction).readDouble(arrayAsResultSet, ARRAY_RESULT_SET_VALUE_COLUMN)); + elementType.writeDouble(builder, + ((DoubleReadFunction) elementReadFunction).readDouble(arrayAsResultSet, + ARRAY_RESULT_SET_VALUE_COLUMN)); } else if (elementType.getJavaType() == Slice.class) { - elementType.writeSlice(builder, ((SliceReadFunction) elementReadFunction).readSlice(arrayAsResultSet, ARRAY_RESULT_SET_VALUE_COLUMN)); + elementType.writeSlice(builder, + ((SliceReadFunction) elementReadFunction).readSlice(arrayAsResultSet, + ARRAY_RESULT_SET_VALUE_COLUMN)); } else { - elementType.writeObject(builder, ((ObjectReadFunction) elementReadFunction).readObject(arrayAsResultSet, ARRAY_RESULT_SET_VALUE_COLUMN)); + elementType.writeObject(builder, + ((ObjectReadFunction) elementReadFunction).readObject(arrayAsResultSet, + ARRAY_RESULT_SET_VALUE_COLUMN)); } } } diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/decode/csv/CsvRowDecoder.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/decode/csv/CsvRowDecoder.java index cfdfaf6eafce8..c52bf913b6c9f 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/decode/csv/CsvRowDecoder.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/decode/csv/CsvRowDecoder.java @@ -206,7 +206,8 @@ private static BiFunction decodeBytes() return new ColumnValue(value, value.getRetainedSize()); } catch (SQLException e) { - throw new IllegalArgumentException(format(DECODE_VALUE_ERROR_MSG_TEMPLATE, data, ConnectorDataType.BYTEA)); + throw new IllegalArgumentException( + format(DECODE_VALUE_ERROR_MSG_TEMPLATE, data, ConnectorDataType.BYTEA)); } }; } @@ -295,7 +296,8 @@ private static BiFunction decodeDouble() private static BiFunction decodeDate() { - return (_, data) -> new ColumnValue(LocalDate.parse(data, DATE_TYPE_FORMATTER).toEpochDay(), SizeOf.LONG_INSTANCE_SIZE); + return (_, data) -> new ColumnValue(LocalDate.parse(data, DATE_TYPE_FORMATTER).toEpochDay(), + SizeOf.LONG_INSTANCE_SIZE); } private static BiFunction decodeTime() @@ -324,14 +326,16 @@ private static BiFunction decodeTimestampSh return (_, data) -> { LongTimestampWithTimeZone timestampValue = decodeTimestampWithTimeZoneLong(data); long millisUtc = timestampValue.getEpochMillis(); - long value = DateTimeEncoding.packDateTimeWithZone(millisUtc, TimeZoneKey.getTimeZoneKey(timestampValue.getTimeZoneKey())); + long value = DateTimeEncoding.packDateTimeWithZone(millisUtc, + TimeZoneKey.getTimeZoneKey(timestampValue.getTimeZoneKey())); return new ColumnValue(value, SizeOf.sizeOf(value)); }; } private static BiFunction decodeTimestampLongWithTimeZone() { - return (_, data) -> new ColumnValue(decodeTimestampWithTimeZoneLong(data), LongTimestampWithTimeZone.INSTANCE_SIZE); + return (_, data) -> new ColumnValue(decodeTimestampWithTimeZoneLong(data), + LongTimestampWithTimeZone.INSTANCE_SIZE); } private static BiFunction decodeTimestampWithoutTimeZone() @@ -353,7 +357,8 @@ else if ("f".equals(data)) { return false; } else { - throw new IllegalArgumentException(format(DECODE_VALUE_ERROR_MSG_TEMPLATE, data, ConnectorDataType.BOOLEAN)); + throw new IllegalArgumentException( + format(DECODE_VALUE_ERROR_MSG_TEMPLATE, data, ConnectorDataType.BOOLEAN)); } } diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/encode/AbstractRowEncoder.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/encode/AbstractRowEncoder.java index 857f84d0b226b..ba273b25e1e71 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/encode/AbstractRowEncoder.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/encode/AbstractRowEncoder.java @@ -82,26 +82,46 @@ protected AbstractRowEncoder(ConnectorSession session, List colu this.columnDataTypes = ImmutableList.copyOf(columnDataTypes); this.currentColumnIndex = 0; map = Map.ofEntries( - Map.entry(ConnectorDataType.BOOLEAN, (encoder, pageBlock) -> encoder.appendBoolean(BOOLEAN.getBoolean(pageBlock.block(), pageBlock.position()))), - Map.entry(ConnectorDataType.MONEY, (encoder, pageBlock) -> encoder.appendString(((MoneyDataType) pageBlock.columnDataType()).getVarcharType().getSlice(pageBlock.block(), pageBlock.position()).toStringUtf8())), - Map.entry(ConnectorDataType.UUID, (encoder, pageBlock) -> encoder.appendString(UuidType.trinoUuidToJavaUuid(UuidType.UUID.getSlice(pageBlock.block(), pageBlock.position())).toString())), - Map.entry(ConnectorDataType.JSONB, (encoder, pageBlock) -> encoder.appendString(VarcharType.VARCHAR.getSlice(pageBlock.block(), pageBlock.position()).toStringUtf8())), - Map.entry(ConnectorDataType.BIT, (encoder, pageBlock) -> encoder.appendString(BooleanType.BOOLEAN.getBoolean(pageBlock.block(), pageBlock.position()) ? "1" : "0")), - Map.entry(ConnectorDataType.BIGINT, (encoder, pageBlock) -> encoder.appendLong(BIGINT.getLong(pageBlock.block(), pageBlock.position()))), - Map.entry(ConnectorDataType.BYTEA, (encoder, pageBlock) -> encoder.appendBytes(VarbinaryType.VARBINARY.getSlice(pageBlock.block(), pageBlock.position()).getBytes())), - Map.entry(ConnectorDataType.CHAR, (encoder, pageBlock) -> encoder.appendString(((CharDataType) pageBlock.columnDataType()).getCharType().getSlice(pageBlock.block(), pageBlock.position()).toStringUtf8())), - Map.entry(ConnectorDataType.VARCHAR, (encoder, pageBlock) -> encoder.appendString(((VarcharDataType) pageBlock.columnDataType()).getVarcharType().getSlice(pageBlock.block(), pageBlock.position()).toStringUtf8())), + Map.entry(ConnectorDataType.BOOLEAN, (encoder, pageBlock) -> encoder.appendBoolean( + BOOLEAN.getBoolean(pageBlock.block(), pageBlock.position()))), + Map.entry(ConnectorDataType.MONEY, (encoder, pageBlock) -> encoder.appendString( + ((MoneyDataType) pageBlock.columnDataType()).getVarcharType() + .getSlice(pageBlock.block(), pageBlock.position()).toStringUtf8())), + Map.entry(ConnectorDataType.UUID, (encoder, pageBlock) -> encoder.appendString( + UuidType.trinoUuidToJavaUuid(UuidType.UUID.getSlice(pageBlock.block(), pageBlock.position())) + .toString())), + Map.entry(ConnectorDataType.JSONB, (encoder, pageBlock) -> encoder.appendString( + VarcharType.VARCHAR.getSlice(pageBlock.block(), pageBlock.position()).toStringUtf8())), + Map.entry(ConnectorDataType.BIT, (encoder, pageBlock) -> encoder.appendString( + BooleanType.BOOLEAN.getBoolean(pageBlock.block(), pageBlock.position()) ? "1" : "0")), + Map.entry(ConnectorDataType.BIGINT, (encoder, pageBlock) -> encoder.appendLong( + BIGINT.getLong(pageBlock.block(), pageBlock.position()))), + Map.entry(ConnectorDataType.BYTEA, (encoder, pageBlock) -> encoder.appendBytes( + VarbinaryType.VARBINARY.getSlice(pageBlock.block(), pageBlock.position()).getBytes())), + Map.entry(ConnectorDataType.CHAR, (encoder, pageBlock) -> encoder.appendString( + ((CharDataType) pageBlock.columnDataType()).getCharType() + .getSlice(pageBlock.block(), pageBlock.position()).toStringUtf8())), + Map.entry(ConnectorDataType.VARCHAR, (encoder, pageBlock) -> encoder.appendString( + ((VarcharDataType) pageBlock.columnDataType()).getVarcharType() + .getSlice(pageBlock.block(), pageBlock.position()).toStringUtf8())), Map.entry(ConnectorDataType.DECIMAL_SHORT, AbstractRowEncoder::encodeDecimalShort), Map.entry(ConnectorDataType.DECIMAL_LONG, AbstractRowEncoder::encodeDecimalLong), - Map.entry(ConnectorDataType.INTEGER, (encoder, pageBlock) -> encoder.appendInt(INTEGER.getInt(pageBlock.block(), pageBlock.position()))), + Map.entry(ConnectorDataType.INTEGER, (encoder, pageBlock) -> encoder.appendInt( + INTEGER.getInt(pageBlock.block(), pageBlock.position()))), Map.entry(ConnectorDataType.SMALLINT, AbstractRowEncoder::encodeSmallint), - Map.entry(ConnectorDataType.REAL, (encoder, pageBlock) -> encoder.appendFloat(intBitsToFloat(toIntExact(RealType.REAL.getLong(pageBlock.block(), pageBlock.position()))))), - Map.entry(ConnectorDataType.DOUBLE_PRECISION, (encoder, pageBlock) -> encoder.appendDouble(DoubleType.DOUBLE.getDouble(pageBlock.block(), pageBlock.position()))), - Map.entry(ConnectorDataType.DATE, (encoder, pageBlock) -> encoder.appendDate(LocalDate.ofEpochDay(DateType.DATE.getLong(pageBlock.block(), pageBlock.position())))), + Map.entry(ConnectorDataType.REAL, (encoder, pageBlock) -> encoder.appendFloat( + intBitsToFloat(toIntExact(RealType.REAL.getLong(pageBlock.block(), pageBlock.position()))))), + Map.entry(ConnectorDataType.DOUBLE_PRECISION, (encoder, pageBlock) -> encoder.appendDouble( + DoubleType.DOUBLE.getDouble(pageBlock.block(), pageBlock.position()))), + Map.entry(ConnectorDataType.DATE, (encoder, pageBlock) -> encoder.appendDate( + LocalDate.ofEpochDay(DateType.DATE.getLong(pageBlock.block(), pageBlock.position())))), Map.entry(ConnectorDataType.TIME, AbstractRowEncoder::encodeTime), - Map.entry(ConnectorDataType.TIMESTAMP_SHORT_WITH_TIME_ZONE, AbstractRowEncoder::encodeTimestampShortWithTimeZone), - Map.entry(ConnectorDataType.TIMESTAMP_LONG_WITH_TIME_ZONE, AbstractRowEncoder::encodeTimestampLongWithTimeZone), - Map.entry(ConnectorDataType.TIMESTAMP_WITHOUT_TIME_ZONE, AbstractRowEncoder::encodeTimestampWithoutTimeZone)); + Map.entry(ConnectorDataType.TIMESTAMP_SHORT_WITH_TIME_ZONE, + AbstractRowEncoder::encodeTimestampShortWithTimeZone), + Map.entry(ConnectorDataType.TIMESTAMP_LONG_WITH_TIME_ZONE, + AbstractRowEncoder::encodeTimestampLongWithTimeZone), + Map.entry(ConnectorDataType.TIMESTAMP_WITHOUT_TIME_ZONE, + AbstractRowEncoder::encodeTimestampWithoutTimeZone)); } @Override @@ -137,9 +157,11 @@ private static void encodeTimestampShortWithTimeZone(AbstractRowEncoder encoder, private static void encodeTimestampLongWithTimeZone(AbstractRowEncoder encoder, EncoderMetadata pageBlock) { TimestampLongWithTimeZoneDataType dataType = (TimestampLongWithTimeZoneDataType) pageBlock.columnDataType(); - LongTimestampWithTimeZone value = (LongTimestampWithTimeZone) dataType.getTrinoType().getObject(pageBlock.block(), pageBlock.position()); + LongTimestampWithTimeZone value = + (LongTimestampWithTimeZone) dataType.getTrinoType().getObject(pageBlock.block(), pageBlock.position()); long epochSeconds = Math.floorDiv(value.getEpochMillis(), 1000); - long nanosOfSecond = (long) Math.floorMod(value.getEpochMillis(), 1000) * 1000000L + (long) (value.getPicosOfMilli() / 1000); + long nanosOfSecond = + (long) Math.floorMod(value.getEpochMillis(), 1000) * 1000000L + (long) (value.getPicosOfMilli() / 1000); Instant instant = Instant.ofEpochSecond(epochSeconds, nanosOfSecond); encoder.appendDateTimeWithTimeZone(OffsetDateTime.ofInstant(instant, TimeZoneKey.UTC_KEY.getZoneId())); } @@ -156,15 +178,18 @@ private static void encodeDecimalShort(AbstractRowEncoder encoder, EncoderMetada { DecimalType decimalType = ((DecimalShortDataType) pageBlock.columnDataType()).getDecimalType(); BigInteger unscaledValue = BigInteger.valueOf(decimalType.getLong(pageBlock.block(), pageBlock.position())); - BigDecimal value = new BigDecimal(unscaledValue, decimalType.getScale(), new MathContext(decimalType.getPrecision())); + BigDecimal value = + new BigDecimal(unscaledValue, decimalType.getScale(), new MathContext(decimalType.getPrecision())); encoder.appendBigDecimal(value); } private static void encodeDecimalLong(AbstractRowEncoder encoder, EncoderMetadata pageBlock) { DecimalType decimalType = ((DecimalShortDataType) pageBlock.columnDataType()).getDecimalType(); - BigInteger unscaledValue = ((Int128) decimalType.getObject(pageBlock.block(), pageBlock.position())).toBigInteger(); - BigDecimal value = new BigDecimal(unscaledValue, decimalType.getScale(), new MathContext(decimalType.getPrecision())); + BigInteger unscaledValue = + ((Int128) decimalType.getObject(pageBlock.block(), pageBlock.position())).toBigInteger(); + BigDecimal value = + new BigDecimal(unscaledValue, decimalType.getScale(), new MathContext(decimalType.getPrecision())); encoder.appendBigDecimal(value); } diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/encode/csv/CsvFormatConfig.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/encode/csv/CsvFormatConfig.java index 59a72ee2e6a2c..bb8438664ab41 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/encode/csv/CsvFormatConfig.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/encode/csv/CsvFormatConfig.java @@ -16,19 +16,15 @@ import io.trino.plugin.adb.connector.encode.DataFormat; import io.trino.plugin.adb.connector.encode.DataFormatConfig; +import java.util.Optional; + public class CsvFormatConfig implements DataFormatConfig { private char delimiter = '|'; - private String nullValue; + private Optional nullValue = Optional.empty(); private String encoding = "UTF-8"; - public CsvFormatConfig() - { - //default value, otherwise checkstyle plugin will raise error - nullValue = null; - } - public static CsvFormatConfig create() { return new CsvFormatConfig(); @@ -42,7 +38,7 @@ public CsvFormatConfig delimiter(char delimiter) public CsvFormatConfig nullValue(String nullValue) { - this.nullValue = nullValue; + this.nullValue = Optional.ofNullable(nullValue); return this; } @@ -51,7 +47,7 @@ public char getDelimiter() return delimiter; } - public String getNullValue() + public Optional getNullValue() { return nullValue; } diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/encode/csv/CsvRowEncoder.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/encode/csv/CsvRowEncoder.java index 046d4215b3627..de110c9771cfb 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/encode/csv/CsvRowEncoder.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/encode/csv/CsvRowEncoder.java @@ -62,7 +62,7 @@ public CsvRowEncoder(ConnectorSession session, List columnDataTy @Override protected void appendNullValue() { - row[currentColumnIndex] = encoderConfig.getNullValue(); + row[currentColumnIndex] = encoderConfig.getNullValue().orElse(null); } @Override diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteBooleanConstant.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteBooleanConstant.java new file mode 100644 index 0000000000000..c1f142fb1d198 --- /dev/null +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteBooleanConstant.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.adb.connector.expression; + +import com.google.common.collect.ImmutableList; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.adb.connector.AdbPushdownSessionProperties; +import io.trino.plugin.base.expression.ConnectorExpressionPatterns; +import io.trino.plugin.base.expression.ConnectorExpressionRule; +import io.trino.plugin.jdbc.QueryParameter; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.expression.Constant; +import io.trino.spi.type.BooleanType; +import io.trino.spi.type.Type; + +import java.util.Optional; + +public class AdbRewriteBooleanConstant + implements ConnectorExpressionRule +{ + private static final Pattern PATTERN = ConnectorExpressionPatterns.constant() + .with(ConnectorExpressionPatterns.type().matching(type -> type == BooleanType.BOOLEAN)); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public boolean isEnabled(ConnectorSession session) + { + return AdbPushdownSessionProperties.isPushdownLiterals(session); + } + + @Override + public Optional rewrite(Constant constant, Captures captures, + RewriteContext context) + { + Type type = constant.getType(); + Object value = constant.getValue(); + return value == null ? Optional.empty() : Optional.of( + new ParameterizedExpression("?", ImmutableList.of(new QueryParameter(type, Optional.of(value))))); + } +} diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteCast.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteCast.java new file mode 100644 index 0000000000000..066141e43cfdf --- /dev/null +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteCast.java @@ -0,0 +1,90 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.adb.connector.expression; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.adb.connector.AdbPushdownSessionProperties; +import io.trino.plugin.adb.connector.datatype.mapper.DataTypeMapper; +import io.trino.plugin.base.expression.ConnectorExpressionPatterns; +import io.trino.plugin.base.expression.ConnectorExpressionRule; +import io.trino.plugin.jdbc.QueryParameter; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; +import io.trino.spi.StandardErrorCode; +import io.trino.spi.TrinoException; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.expression.Call; +import io.trino.spi.expression.ConnectorExpression; +import io.trino.spi.expression.StandardFunctions; + +import java.util.Optional; + +public class AdbRewriteCast + implements ConnectorExpressionRule +{ + private static final Capture ARG = Capture.newCapture(); + private static final Pattern PATTERN = ConnectorExpressionPatterns.call() + .with(ConnectorExpressionPatterns.functionName().equalTo(StandardFunctions.CAST_FUNCTION_NAME)) + .with(ConnectorExpressionPatterns.argument(0) + .matching(ConnectorExpressionPatterns.expression().capturedAs(ARG))); + private final DataTypeMapper dataTypeMapper; + + public AdbRewriteCast(DataTypeMapper dataTypeMapper) + { + this.dataTypeMapper = dataTypeMapper; + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public boolean isEnabled(ConnectorSession session) + { + return AdbPushdownSessionProperties.isPushdownFunctionCast(session); + } + + @Override + public Optional rewrite(Call expression, Captures captures, + RewriteContext context) + { + try { + String returnType = dataTypeMapper.toWriteMapping(context.getSession(), expression.getType()).getDataType(); + Optional arg = context.defaultRewrite(captures.get(ARG)); + if (arg.isEmpty()) { + return Optional.empty(); + } + else { + Builder parameters = ImmutableList.builder(); + parameters.addAll(arg.get().parameters()); + return Optional.of( + new ParameterizedExpression(String.format("CAST(%s AS %s)", arg.get().expression(), returnType), + parameters.build())); + } + } + catch (TrinoException e) { + if (e.getErrorCode() == StandardErrorCode.NOT_SUPPORTED.toErrorCode()) { + return Optional.empty(); + } + else { + throw e; + } + } + } +} diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteCharConstant.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteCharConstant.java new file mode 100644 index 0000000000000..9f700fee2ff5d --- /dev/null +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteCharConstant.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.adb.connector.expression; + +import com.google.common.collect.ImmutableList; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.adb.connector.AdbPushdownSessionProperties; +import io.trino.plugin.base.expression.ConnectorExpressionPatterns; +import io.trino.plugin.base.expression.ConnectorExpressionRule; +import io.trino.plugin.jdbc.QueryParameter; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.expression.Constant; +import io.trino.spi.type.CharType; +import io.trino.spi.type.Type; + +import java.util.Optional; + +public class AdbRewriteCharConstant + implements ConnectorExpressionRule +{ + private static final Pattern PATTERN = ConnectorExpressionPatterns.constant() + .with(ConnectorExpressionPatterns.type().matching(type -> type instanceof CharType)); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public boolean isEnabled(ConnectorSession session) + { + return AdbPushdownSessionProperties.isPushdownLiterals(session); + } + + @Override + public Optional rewrite(Constant constant, Captures captures, + RewriteContext context) + { + Type type = constant.getType(); + Object value = constant.getValue(); + return value == null ? Optional.empty() : + Optional.of(new ParameterizedExpression("?", + ImmutableList.of(new QueryParameter(type, Optional.of(value))))); + } +} diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteDatetimeArithmetics.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteDatetimeArithmetics.java new file mode 100644 index 0000000000000..086e5850a6bae --- /dev/null +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteDatetimeArithmetics.java @@ -0,0 +1,109 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.adb.connector.expression; + +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.adb.connector.AdbPushdownSessionProperties; +import io.trino.plugin.base.expression.ConnectorExpressionPatterns; +import io.trino.plugin.base.expression.ConnectorExpressionRule; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.expression.Call; +import io.trino.spi.expression.ConnectorExpression; +import io.trino.spi.expression.Constant; +import io.trino.spi.expression.StandardFunctions; +import io.trino.spi.type.DateType; +import io.trino.spi.type.TimeType; +import io.trino.spi.type.TimestampType; +import io.trino.spi.type.TimestampWithTimeZoneType; +import io.trino.spi.type.Type; + +import java.util.Optional; + +public class AdbRewriteDatetimeArithmetics + implements ConnectorExpressionRule +{ + private static final String TYPE_NAME_DAY_SECOND = "interval day to second"; + private static final String TYPE_NAME_YEAR_MONTH = "interval year to month"; + private static final Capture ARG0 = Capture.newCapture(); + private static final Capture ARG1 = Capture.newCapture(); + private static final Pattern PATTERN = ConnectorExpressionPatterns.call() + .with(ConnectorExpressionPatterns.functionName() + .matching(n -> StandardFunctions.ADD_FUNCTION_NAME.equals(n) || + StandardFunctions.SUBTRACT_FUNCTION_NAME.equals(n))) + .with(ConnectorExpressionPatterns.argument(0) + .matching(ConnectorExpressionPatterns.expression() + .capturedAs(ARG0) + .with(ConnectorExpressionPatterns.type() + .matching(type -> type == DateType.DATE + || type instanceof TimeType + || type instanceof TimestampType + || type instanceof TimestampWithTimeZoneType)))) + .with(ConnectorExpressionPatterns.argument(1) + .matching(ConnectorExpressionPatterns.constant() + .matching(c -> isInterval(c.getType())).capturedAs(ARG1))); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public boolean isEnabled(ConnectorSession session) + { + return AdbPushdownSessionProperties.isPushdownDatetimeArithmetics(session); + } + + @Override + public Optional rewrite(Call expression, Captures captures, + RewriteContext context) + { + Constant arg1 = captures.get(ARG1); + if (arg1.getValue() == null) { + return Optional.empty(); + } + else { + Optional arg0 = context.defaultRewrite(captures.get(ARG0)); + if (arg0.isEmpty()) { + return Optional.empty(); + } + else { + String arg1Caption = String.format("interval '%d %s'", (Long) arg1.getValue(), + isDaySecondInterval(arg1.getType()) ? "milliseconds" : "months"); + String operator = expression.getFunctionName() == StandardFunctions.ADD_FUNCTION_NAME ? "+" : "-"; + return Optional.of(new ParameterizedExpression( + String.format("%s %s %s", arg0.get().expression(), operator, arg1Caption), + arg0.get().parameters())); + } + } + } + + private static boolean isInterval(Type type) + { + return isDaySecondInterval(type) || isYearMonthInterval(type); + } + + private static boolean isDaySecondInterval(Type type) + { + return TYPE_NAME_DAY_SECOND.equals(type.getDisplayName()); + } + + private static boolean isYearMonthInterval(Type type) + { + return TYPE_NAME_YEAR_MONTH.equals(type.getDisplayName()); + } +} diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteInexactNumericConstant.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteInexactNumericConstant.java new file mode 100644 index 0000000000000..590ffc1644053 --- /dev/null +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteInexactNumericConstant.java @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.adb.connector.expression; + +import com.google.common.collect.ImmutableList; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.adb.connector.AdbPushdownSessionProperties; +import io.trino.plugin.base.expression.ConnectorExpressionPatterns; +import io.trino.plugin.base.expression.ConnectorExpressionRule; +import io.trino.plugin.jdbc.QueryParameter; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.expression.Constant; +import io.trino.spi.type.DoubleType; +import io.trino.spi.type.RealType; +import io.trino.spi.type.Type; + +import java.util.Optional; + +public class AdbRewriteInexactNumericConstant + implements ConnectorExpressionRule +{ + private static final Pattern PATTERN = ConnectorExpressionPatterns.constant() + .with(ConnectorExpressionPatterns.type() + .matching(type -> type == RealType.REAL || type == DoubleType.DOUBLE)); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public boolean isEnabled(ConnectorSession session) + { + return AdbPushdownSessionProperties.isPushdownLiterals(session); + } + + @Override + public Optional rewrite(Constant constant, Captures captures, + RewriteContext context) + { + Type type = constant.getType(); + Object value = constant.getValue(); + return value == null ? Optional.empty() : + Optional.of(new ParameterizedExpression("?", + ImmutableList.of(new QueryParameter(type, Optional.of(value))))); + } +} diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteNullConstant.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteNullConstant.java new file mode 100644 index 0000000000000..1b367f803ceaa --- /dev/null +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteNullConstant.java @@ -0,0 +1,78 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.adb.connector.expression; + +import com.google.common.collect.ImmutableList; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.adb.connector.AdbPushdownSessionProperties; +import io.trino.plugin.adb.connector.datatype.mapper.DataTypeMapper; +import io.trino.plugin.base.expression.ConnectorExpressionPatterns; +import io.trino.plugin.base.expression.ConnectorExpressionRule; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; +import io.trino.spi.StandardErrorCode; +import io.trino.spi.TrinoException; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.expression.Constant; + +import java.util.Optional; + +public class AdbRewriteNullConstant + implements ConnectorExpressionRule +{ + private static final Pattern PATTERN = ConnectorExpressionPatterns.constant(); + private final DataTypeMapper dataTypeMapper; + + public AdbRewriteNullConstant(DataTypeMapper dataTypeMapper) + { + this.dataTypeMapper = dataTypeMapper; + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public boolean isEnabled(ConnectorSession session) + { + return AdbPushdownSessionProperties.isPushdownLiterals(session); + } + + @Override + public Optional rewrite(Constant constant, Captures captures, + RewriteContext context) + { + if (constant.getValue() != null) { + return Optional.empty(); + } + else { + try { + String returnType = + dataTypeMapper.toWriteMapping(context.getSession(), constant.getType()).getDataType(); + return Optional.of( + new ParameterizedExpression(String.format("CAST(NULL AS %s)", returnType), ImmutableList.of())); + } + catch (TrinoException e) { + if (e.getErrorCode() == StandardErrorCode.NOT_SUPPORTED.toErrorCode()) { + return Optional.empty(); + } + else { + throw e; + } + } + } + } +} diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/metadata/AdbMetadataDao.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/metadata/AdbMetadataDao.java index 4fa4d511bfc14..579ff2cff6f27 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/metadata/AdbMetadataDao.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/metadata/AdbMetadataDao.java @@ -26,5 +26,6 @@ int getSegmentCount(ConnectorSession session) boolean isSegmentedTable(ConnectorSession session, String objectName); - Map getTableProperties(ConnectorSession session, String objectName, IdentifierMapping identifierMapping); + Map getTableProperties(ConnectorSession session, String objectName, + IdentifierMapping identifierMapping); } diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/metadata/impl/AdbMetadataDaoImpl.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/metadata/impl/AdbMetadataDaoImpl.java index 615abb3122d96..8d7b0fbae2a6b 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/metadata/impl/AdbMetadataDaoImpl.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/metadata/impl/AdbMetadataDaoImpl.java @@ -58,7 +58,8 @@ public int getSegmentCount(ConnectorSession session) try { int segmentCount = 0; try (Connection connection = this.connectionFactory.openConnection(session); - ResultSet rs = connection.createStatement().executeQuery("SELECT MAX(content) + 1 FROM pg_catalog.gp_segment_configuration")) { + ResultSet rs = connection.createStatement() + .executeQuery("SELECT MAX(content) + 1 FROM pg_catalog.gp_segment_configuration")) { if (!rs.next()) { return segmentCount; } @@ -87,17 +88,20 @@ public boolean isSegmentedTable(ConnectorSession session, String objectName) return false; } String distribution = rs.getString(1).trim(); - isDistributed = distribution.startsWith("DISTRIBUTED BY") || distribution.equals("DISTRIBUTED RANDOMLY"); + isDistributed = + distribution.startsWith("DISTRIBUTED BY") || distribution.equals("DISTRIBUTED RANDOMLY"); } return isDistributed; } catch (SQLException e) { - throw new TrinoException(JdbcErrorCode.JDBC_ERROR, "Failed to determine whether the table contains the segment ID column.", e); + throw new TrinoException(JdbcErrorCode.JDBC_ERROR, + "Failed to determine whether the table contains the segment ID column.", e); } } @Override - public Map getTableProperties(ConnectorSession session, String objectName, IdentifierMapping identifierMapping) + public Map getTableProperties(ConnectorSession session, String objectName, + IdentifierMapping identifierMapping) { String sql = String.format( "WITH oid AS (SELECT '%s'::regclass::oid oid)\n" + diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/AbstractExternalTableQueryFactory.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/AbstractExternalTableQueryFactory.java index 8ac40af9df07e..a5da0294dc712 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/AbstractExternalTableQueryFactory.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/AbstractExternalTableQueryFactory.java @@ -41,7 +41,7 @@ protected String createCommonQuery(GpfdistMetadata metadata) metadata.getGpfdistLocation(), metadata.getExternalTableFormatConfig().dataFormat().name(), metadata.getExternalTableFormatConfig().delimiter(), - metadata.getExternalTableFormatConfig().nullValue() == null ? "" : metadata.getExternalTableFormatConfig().nullValue(), + metadata.getExternalTableFormatConfig().nullValue(), metadata.getExternalTableFormatConfig().encoding()); } } diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/GpfdistModule.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/GpfdistModule.java index 171e8e03f2a70..f6b3177c0e4f2 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/GpfdistModule.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/GpfdistModule.java @@ -74,15 +74,22 @@ public class GpfdistModule public void setup(Binder binder) { install(new GpfdistServerModule()); - binder.bind(ExternalTableFormatConfigFactory.class).to(ExternalTableFormatConfigFactoryImpl.class).in(Scopes.SINGLETON); - Multibinder createExtTableQueryFactories = Multibinder.newSetBinder(binder, CreateExternalTableQueryFactory.class); - createExtTableQueryFactories.addBinding().to(CreateReadableExternalTableQueryFactory.class).in(Scopes.SINGLETON); - createExtTableQueryFactories.addBinding().to(CreateWritableExternalTableQueryFactory.class).in(Scopes.SINGLETON); - Multibinder insertDataQueryFactories = Multibinder.newSetBinder(binder, InsertDataQueryFactory.class); + binder.bind(ExternalTableFormatConfigFactory.class).to(ExternalTableFormatConfigFactoryImpl.class) + .in(Scopes.SINGLETON); + Multibinder createExtTableQueryFactories = + Multibinder.newSetBinder(binder, CreateExternalTableQueryFactory.class); + createExtTableQueryFactories.addBinding().to(CreateReadableExternalTableQueryFactory.class) + .in(Scopes.SINGLETON); + createExtTableQueryFactories.addBinding().to(CreateWritableExternalTableQueryFactory.class) + .in(Scopes.SINGLETON); + Multibinder insertDataQueryFactories = + Multibinder.newSetBinder(binder, InsertDataQueryFactory.class); insertDataQueryFactories.addBinding().to(InsertDataFromExternalTableQueryFactory.class).in(Scopes.SINGLETON); insertDataQueryFactories.addBinding().to(InsertDataToExternalTableQueryFactory.class).in(Scopes.SINGLETON); - OptionalBinder.newOptionalBinder(binder, ConnectorPageSinkProvider.class).setBinding().to(GpfdistPageSinkProvider.class).in(Scopes.SINGLETON); - OptionalBinder.newOptionalBinder(binder, ConnectorRecordSetProvider.class).setBinding().to(GpfdistRecordSetProvider.class).in(Scopes.SINGLETON); + OptionalBinder.newOptionalBinder(binder, ConnectorPageSinkProvider.class).setBinding() + .to(GpfdistPageSinkProvider.class).in(Scopes.SINGLETON); + OptionalBinder.newOptionalBinder(binder, ConnectorRecordSetProvider.class).setBinding() + .to(GpfdistRecordSetProvider.class).in(Scopes.SINGLETON); binder.bind(GpfdistLoadMetadataFactory.class).to(GpfdistLoadMetadataFactoryImpl.class).in(Scopes.SINGLETON); binder.bind(GpfdistUnloadMetadataFactory.class).to(GpfdistUnloadMetadataFactoryImpl.class).in(Scopes.SINGLETON); binder.bind(GpfdistLocationFactory.class).to(GpfdistLocationFactoryImpl.class).in(Scopes.SINGLETON); @@ -102,11 +109,13 @@ public void setup(Binder binder) ExportBinder.newExporter(binder) .export(RequestStats.class) .as(generator -> generator.generatedNameOf(RequestStats.class, "adb-gpfdist-server-request-stats")); - ExportBinder.newExporter(binder).export(HttpServer.class).as(generator -> generator.generatedNameOf(HttpServer.class, "adb-gpfdist-server")); + ExportBinder.newExporter(binder).export(HttpServer.class) + .as(generator -> generator.generatedNameOf(HttpServer.class, "adb-gpfdist-server")); binder.bind(HttpServer.class).toProvider(HttpServerProvider.class).in(Scopes.SINGLETON); } - public static MapBinder externalTableQueriesFactoryMap(Binder binder) + public static MapBinder externalTableQueriesFactoryMap( + Binder binder) { return newMapBinder(binder, ExternalTableType.class, CreateExternalTableQueryFactory.class); } @@ -130,7 +139,8 @@ public static ConnectionFactory getConnectionFactory(BaseJdbcConfig config, @Singleton public static NodeConfig getNodeConfig(GpfdistServerConfig config, NodeManager nodeManager) { - String internalHost = config.getServerHost() != null ? config.getServerHost() : nodeManager.getCurrentNode().getHost(); + String internalHost = + config.getServerHost() != null ? config.getServerHost() : nodeManager.getCurrentNode().getHost(); String externalHost = config.getServerExternalHost() != null ? config.getServerExternalHost() : internalHost; return new NodeConfig() .setEnvironment("adb_gpfdist") diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageSinkProvider.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageSinkProvider.java index 793bf35ef9275..5ae189e5b7c07 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageSinkProvider.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageSinkProvider.java @@ -78,8 +78,10 @@ public GpfdistPageSinkProvider(@ForBaseJdbc JdbcClient client, this.rowEncoderFactory = rowEncoderFactory; this.externalTableFormatConfigFactory = externalTableFormatConfigFactory; this.loadQueryThreadExecutor = ExecutorServiceProvider.LOAD_DATA_QUERY_EXECUTOR_SERVICE; - Map externalTableQueryFactoryMap = createExternalTableQueryFactories.stream() - .collect(Collectors.toMap(CreateExternalTableQueryFactory::getExternalTableType, Function.identity())); + Map externalTableQueryFactoryMap = + createExternalTableQueryFactories.stream() + .collect(Collectors.toMap(CreateExternalTableQueryFactory::getExternalTableType, + Function.identity())); externalTableCreateQueryFactory = externalTableQueryFactoryMap.get(EXTERNAL_TABLE_TYPE); checkArgument(externalTableCreateQueryFactory != null, "failed to get writable table query factory by externalTableType %s", @@ -107,7 +109,8 @@ public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHa ConnectorInsertTableHandle insertTableHandle, ConnectorPageSinkId pageSinkId) { - return createPageSinkInternal(transactionHandle, session, (ConnectorOutputTableHandle) insertTableHandle, pageSinkId); + return createPageSinkInternal(transactionHandle, session, (ConnectorOutputTableHandle) insertTableHandle, + pageSinkId); } private ConnectorPageSink createPageSinkInternal(ConnectorTransactionHandle transactionHandle, diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/metadata/ExternalTableFormatConfigFactoryImpl.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/metadata/ExternalTableFormatConfigFactoryImpl.java index 8ddbe7a8a2e23..6ccce9ac1beba 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/metadata/ExternalTableFormatConfigFactoryImpl.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/metadata/ExternalTableFormatConfigFactoryImpl.java @@ -36,7 +36,7 @@ public ExternalTableFormatConfig create() CsvFormatConfig csvFormatConfig = (CsvFormatConfig) dataFormatConfig; return new ExternalTableFormatConfig(csvFormatConfig.getDelimiter(), csvFormatConfig.getEncoding(), - csvFormatConfig.getNullValue(), + csvFormatConfig.getNullValue().orElse(""), dataFormatConfig.getDataFormat()); } else { diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/metadata/GpfdistLocationFactoryImpl.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/metadata/GpfdistLocationFactoryImpl.java index 6330da691c57d..7ce4bad0c15ba 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/metadata/GpfdistLocationFactoryImpl.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/metadata/GpfdistLocationFactoryImpl.java @@ -36,7 +36,8 @@ public GpfdistLocationFactoryImpl(GpfdistServerConfig config, HttpServerInfo htt public String create(String externalTableName, ExternalTableType externalTableType) { String protocol = config.isServerSslEnabled() ? "gpfdists" : "gpfdist"; - URI uri = config.isServerSslEnabled() ? httpServerInfo.getHttpsExternalUri() : httpServerInfo.getHttpExternalUri(); + URI uri = + config.isServerSslEnabled() ? httpServerInfo.getHttpsExternalUri() : httpServerInfo.getHttpExternalUri(); String host = uri.getHost(); int port = config.getServerExternalPort() > 0 ? config.getServerExternalPort() : uri.getPort(); return String.format("%s://%s:%d/gpfdist/%s/%s", diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/metadata/GpfdistMetadata.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/metadata/GpfdistMetadata.java index fd6463e3d77ea..d2b42b0fb016e 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/metadata/GpfdistMetadata.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/metadata/GpfdistMetadata.java @@ -75,7 +75,10 @@ public boolean equals(Object o) return false; } GpfdistMetadata that = (GpfdistMetadata) o; - return Objects.equals(sourceTable, that.sourceTable) && Objects.equals(columnNames, that.columnNames) && Objects.equals(dataTypes, that.dataTypes) && Objects.equals(externalTableFormatConfig, that.externalTableFormatConfig) && Objects.equals(gpfdistLocation, that.gpfdistLocation); + return Objects.equals(sourceTable, that.sourceTable) && Objects.equals(columnNames, that.columnNames) && + Objects.equals(dataTypes, that.dataTypes) && + Objects.equals(externalTableFormatConfig, that.externalTableFormatConfig) && + Objects.equals(gpfdistLocation, that.gpfdistLocation); } @Override diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/server/GpfdistResource.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/server/GpfdistResource.java index 8468379710963..2e08e64d05163 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/server/GpfdistResource.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/server/GpfdistResource.java @@ -81,7 +81,8 @@ public GpfdistResource(ContextManager writeContextManager, @GET @Produces("text/plain") @Path("/read/{tableName}") - public void get(@PathParam("tableName") String tableName, @Context HttpHeaders headers, @Suspended AsyncResponse asyncResponse) + public void get(@PathParam("tableName") String tableName, @Context HttpHeaders headers, + @Suspended AsyncResponse asyncResponse) { GpfdistReadableRequest request = GpfdistReadableRequest.create(tableName, headers.getRequestHeaders()); checkArgument(request.getGpProtocol() == GPFDIST_FOR_READ_PROTOCOL_VERSION, @@ -101,7 +102,8 @@ public void get(@PathParam("tableName") String tableName, @Context HttpHeaders h } } - private void processGetRequest(AsyncResponse asyncResponse, GpfdistReadableRequest request, WriteContext writeContext) + private void processGetRequest(AsyncResponse asyncResponse, GpfdistReadableRequest request, + WriteContext writeContext) { int bufferSizeInBytes = Long.valueOf(pluginConfig.getWriteBufferSize().toBytes()).intValue(); try (PipedOutputStream outputStream = new PipedOutputStream(); @@ -139,13 +141,15 @@ private Response.ResponseBuilder createOkGetResponseBuilder(GpfdistReadableReque @POST @Consumes("*/*") @Path("/write/{tableName}") - public void post(@PathParam("tableName") String tableName, InputStream data, @Context HttpHeaders headers, @Suspended AsyncResponse asyncResponse) + public void post(@PathParam("tableName") String tableName, InputStream data, @Context HttpHeaders headers, + @Suspended AsyncResponse asyncResponse) { try { GpfdistWritableRequest request = GpfdistWritableRequest.create(tableName, headers.getRequestHeaders()); log.debug("Received POST request: %s", request); checkArgument(request.getGpProtocol() == GPFDIST_FOR_WRITE_PROTOCOL_VERSION, - format("Gpfdist protocol version %s for write operation is supported", GPFDIST_FOR_WRITE_PROTOCOL_VERSION)); + format("Gpfdist protocol version %s for write operation is supported", + GPFDIST_FOR_WRITE_PROTOCOL_VERSION)); Optional readContextOptional = readContextManager.get(new ContextId(tableName)); if (readContextOptional.isEmpty()) { processNotFoundQueryRequest(tableName, asyncResponse, request); @@ -170,7 +174,8 @@ public void post(@PathParam("tableName") String tableName, InputStream data, @Co } } - private static void processNotFoundQueryRequest(String tableName, AsyncResponse asyncResponse, GpfdistWritableRequest request) + private static void processNotFoundQueryRequest(String tableName, AsyncResponse asyncResponse, + GpfdistWritableRequest request) { String errorMessage = "No active query for writeable table: " + tableName; asyncResponse.resume(Response.status(Response.Status.BAD_REQUEST.getStatusCode(), errorMessage) @@ -179,7 +184,8 @@ private static void processNotFoundQueryRequest(String tableName, AsyncResponse log.error("Failed to processed request: %s. " + errorMessage, request); } - private void processInitialRequest(AsyncResponse asyncResponse, ReadContext readContext, GpfdistWritableRequest request) + private void processInitialRequest(AsyncResponse asyncResponse, ReadContext readContext, + GpfdistWritableRequest request) { InputDataProcessor dataProcessor = inputDataProcessorFactory.create(readContext.getRowDecoder(), readContext.getRowProcessor()); @@ -191,7 +197,8 @@ private void processInitialRequest(AsyncResponse asyncResponse, ReadContext read log.debug("Request for initial data transferring completed successfully: %s", request); } - private void processDataRequest(InputStream data, AsyncResponse asyncResponse, ReadContext readContext, GpfdistWritableRequest request) + private void processDataRequest(InputStream data, AsyncResponse asyncResponse, ReadContext readContext, + GpfdistWritableRequest request) { executorService.submit(() -> { try { @@ -208,7 +215,8 @@ private void processDataRequest(InputStream data, AsyncResponse asyncResponse, R }); } - private void processTearDownRequest(AsyncResponse asyncResponse, ReadContext readContext, GpfdistWritableRequest request) + private void processTearDownRequest(AsyncResponse asyncResponse, ReadContext readContext, + GpfdistWritableRequest request) { GpfdistSegmentRequestProcessor processor = getSegmentProcessor(readContext, request.getSegmentId()); processor.stop(); @@ -221,7 +229,8 @@ private void processTearDownRequest(AsyncResponse asyncResponse, ReadContext rea private GpfdistSegmentRequestProcessor getSegmentProcessor(ReadContext readContext, Integer segmentId) { return Optional.ofNullable(readContext.getSegmentDataProcessors().get(segmentId)) - .orElseThrow(() -> new IllegalStateException("Failed to get segment request processor by segmentId: " + segmentId)); + .orElseThrow(() -> new IllegalStateException( + "Failed to get segment request processor by segmentId: " + segmentId)); } private void failWriteResponse(AsyncResponse asyncResponse, Exception e) diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/server/request/GpfdistWritableRequest.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/server/request/GpfdistWritableRequest.java index 72bef9a2f0fa7..42ac64f33f510 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/server/request/GpfdistWritableRequest.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/server/request/GpfdistWritableRequest.java @@ -82,7 +82,8 @@ public static GpfdistWritableRequest create(String tableName, MultivaluedMap Integer.parseInt(v.getFirst())) - .orElseThrow(() -> new IllegalArgumentException("Request header not found: " + X_GP_SEGMENT_ID)), + .orElseThrow( + () -> new IllegalArgumentException("Request header not found: " + X_GP_SEGMENT_ID)), Optional.ofNullable(values.get(X_GP_SEGMENT_COUNT)) .map(v -> Integer.parseInt(v.getFirst())), Optional.ofNullable(values.get(X_GP_LINE_DELIM_LENGTH)) diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/unload/process/GpfdistCsvDataProcessor.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/unload/process/GpfdistCsvDataProcessor.java index 90bd163048c16..d13985269aa0b 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/unload/process/GpfdistCsvDataProcessor.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/unload/process/GpfdistCsvDataProcessor.java @@ -56,7 +56,8 @@ public GpfdistCsvDataProcessor(DataFormatConfig dataFormatConfig, public void process(InputStream dataStream) { try { - try (InputStreamReader streamReader = new InputStreamReader(dataStream, Charset.forName(dataFormatConfig.getEncoding())); + try (InputStreamReader streamReader = new InputStreamReader(dataStream, + Charset.forName(dataFormatConfig.getEncoding())); BufferedReader bufferedStreamReader = new BufferedReader(streamReader)) { CSVReader csvReader = new CSVReaderBuilder(bufferedStreamReader) .withCSVParser(new RFC4180ParserBuilder() @@ -85,7 +86,8 @@ public void process(InputStream dataStream) } } catch (Throwable e) { - throw new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "Failed to process input data: " + e.getMessage(), e); + throw new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, + "Failed to process input data: " + e.getMessage(), e); } } diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/unload/process/GpfdistRecordCursor.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/unload/process/GpfdistRecordCursor.java index 82913816372c4..23ab67a1504d8 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/unload/process/GpfdistRecordCursor.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/unload/process/GpfdistRecordCursor.java @@ -135,7 +135,8 @@ private boolean isDataNotProcessed() { //check that there are no rows in queue and segment processors are not finished yet return rowProcessingService.isEmpty() - && readContext.getSegmentDataProcessors().values().stream().anyMatch(req -> req.getStatus() != SegmentRequestStatus.FINISHED); + && readContext.getSegmentDataProcessors().values().stream() + .anyMatch(req -> req.getStatus() != SegmentRequestStatus.FINISHED); } private boolean isDataTransferNotInitialized() diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/unload/process/GpfdistRecordSetProvider.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/unload/process/GpfdistRecordSetProvider.java index 20b6d34ab7d11..b8badc6f3b293 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/unload/process/GpfdistRecordSetProvider.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/unload/process/GpfdistRecordSetProvider.java @@ -15,7 +15,6 @@ import com.google.inject.Inject; import io.trino.plugin.adb.AdbPluginConfig; -import io.trino.plugin.adb.connector.AdbJdbcSplit; import io.trino.plugin.adb.connector.AdbSessionProperties; import io.trino.plugin.adb.connector.AdbSqlClient; import io.trino.plugin.adb.connector.decode.RowDecoder; @@ -32,6 +31,7 @@ import io.trino.plugin.adb.connector.protocol.gpfdist.metadata.GpfdistUnloadMetadataFactory; import io.trino.plugin.adb.connector.protocol.gpfdist.unload.context.ReadContext; import io.trino.plugin.adb.connector.protocol.gpfdist.unload.query.GpfdistUnloadDataTransferQueryExecutor; +import io.trino.plugin.adb.connector.table.AdbJdbcSplit; import io.trino.plugin.jdbc.ForBaseJdbc; import io.trino.plugin.jdbc.ForRecordCursor; import io.trino.plugin.jdbc.JdbcClient; @@ -88,8 +88,10 @@ public GpfdistRecordSetProvider(@ForBaseJdbc JdbcClient client, this.rowDecoderFactory = rowDecoderFactory; this.contextManager = contextManager; this.unloadQueryThreadExecutor = ExecutorServiceProvider.LOAD_DATA_QUERY_EXECUTOR_SERVICE; - Map externalTableQueryFactoryMap = createExternalTableQueryFactories.stream() - .collect(Collectors.toMap(CreateExternalTableQueryFactory::getExternalTableType, Function.identity())); + Map externalTableQueryFactoryMap = + createExternalTableQueryFactories.stream() + .collect(Collectors.toMap(CreateExternalTableQueryFactory::getExternalTableType, + Function.identity())); externalTableQueryFactory = externalTableQueryFactoryMap.get(EXTERNAL_TABLE_TYPE); checkArgument(externalTableQueryFactory != null, "failed to get writable table query factory by externalTableType %s", diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/unload/query/CreateWritableExternalTableQueryFactory.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/unload/query/CreateWritableExternalTableQueryFactory.java index adfb201d7b5e9..8a8d117684397 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/unload/query/CreateWritableExternalTableQueryFactory.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/unload/query/CreateWritableExternalTableQueryFactory.java @@ -13,11 +13,11 @@ */ package io.trino.plugin.adb.connector.protocol.gpfdist.unload.query; -import io.trino.plugin.adb.connector.AdbJdbcSplit; import io.trino.plugin.adb.connector.protocol.gpfdist.AbstractExternalTableQueryFactory; import io.trino.plugin.adb.connector.protocol.gpfdist.metadata.ExternalTableType; import io.trino.plugin.adb.connector.protocol.gpfdist.metadata.GpfdistMetadata; import io.trino.plugin.adb.connector.protocol.gpfdist.metadata.GpfdistUnloadMetadata; +import io.trino.plugin.adb.connector.table.AdbJdbcSplit; import java.util.HashSet; import java.util.List; diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/CollationAwareQueryBuilder.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/query/CollationAwareQueryBuilder.java similarity index 82% rename from plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/CollationAwareQueryBuilder.java rename to plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/query/CollationAwareQueryBuilder.java index 586a8a04f5d84..eb443f4084e18 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/CollationAwareQueryBuilder.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/query/CollationAwareQueryBuilder.java @@ -11,9 +11,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.plugin.adb.connector; +package io.trino.plugin.adb.connector.query; import com.google.inject.Inject; +import io.trino.plugin.adb.connector.AdbSessionProperties; +import io.trino.plugin.adb.connector.AdbSqlClient; import io.trino.plugin.jdbc.DefaultQueryBuilder; import io.trino.plugin.jdbc.JdbcClient; import io.trino.plugin.jdbc.JdbcColumnHandle; @@ -38,9 +40,11 @@ public CollationAwareQueryBuilder(RemoteQueryModifier queryModifier) super(queryModifier); } - protected String formatJoinCondition(JdbcClient client, String leftRelationAlias, String rightRelationAlias, JdbcJoinCondition condition) + protected String formatJoinCondition(JdbcClient client, String leftRelationAlias, String rightRelationAlias, + JdbcJoinCondition condition) { - boolean isCollatable = Stream.of(condition.getLeftColumn(), condition.getRightColumn()).anyMatch(AdbSqlClient::isCollatable); + boolean isCollatable = + Stream.of(condition.getLeftColumn(), condition.getRightColumn()).anyMatch(AdbSqlClient::isCollatable); return isCollatable ? String.format( "%s.%s COLLATE \"C\" %s %s.%s COLLATE \"C\"", @@ -65,10 +69,12 @@ protected String toPredicate( { if (AdbSqlClient.isCollatable(column) && AdbSessionProperties.isEnableStringPushdownWithCollate(session)) { accumulator.accept(new QueryParameter(jdbcType, type, Optional.of(value))); - return String.format("%s %s %s COLLATE \"C\"", client.quoted(column.getColumnName()), operator, writeFunction.getBindExpression()); + return String.format("%s %s %s COLLATE \"C\"", client.quoted(column.getColumnName()), operator, + writeFunction.getBindExpression()); } else { - return super.toPredicate(client, session, column, jdbcType, type, writeFunction, operator, value, accumulator); + return super.toPredicate(client, session, column, jdbcType, type, writeFunction, operator, value, + accumulator); } } } diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbColumnMapping.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/table/AdbColumnMapping.java similarity index 94% rename from plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbColumnMapping.java rename to plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/table/AdbColumnMapping.java index 1ef610415be7f..9e77c1f71c61e 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbColumnMapping.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/table/AdbColumnMapping.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.plugin.adb.connector; +package io.trino.plugin.adb.connector.table; import io.trino.plugin.adb.connector.datatype.ColumnDataType; import io.trino.plugin.jdbc.ColumnMapping; diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbJdbcSplit.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/table/AdbJdbcSplit.java similarity index 98% rename from plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbJdbcSplit.java rename to plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/table/AdbJdbcSplit.java index b81ae2abd06b2..6d0b31238cb90 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbJdbcSplit.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/table/AdbJdbcSplit.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.plugin.adb.connector; +package io.trino.plugin.adb.connector.table; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/table/AdbTableProperties.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/table/AdbTableProperties.java index 5c2d970eaf59a..cf22c608ac6cb 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/table/AdbTableProperties.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/table/AdbTableProperties.java @@ -58,19 +58,25 @@ public AdbTableProperties(AdbCreateTableStorageConfig config) false, value -> value, value -> value)) - .add(PropertyMetadata.booleanProperty(APPEND_OPTIMIZED_PROPERTY, "Whether table is append-optimized", config.getAppendOptimized(), false)) - .add(PropertyMetadata.integerProperty(BLOCK_SIZE_PROPERTY, "Block size for append-optimized tables", config.getBlockSize(), false)) + .add(PropertyMetadata.booleanProperty(APPEND_OPTIMIZED_PROPERTY, "Whether table is append-optimized", + config.getAppendOptimized(), false)) + .add(PropertyMetadata.integerProperty(BLOCK_SIZE_PROPERTY, "Block size for append-optimized tables", + config.getBlockSize(), false)) .add(PropertyMetadata.enumProperty( - ORIENTATION_PROPERTY, "Table orientation. Valid values: column, row", AdbTableStorageOrientation.class, config.getOrientation(), false)) - .add(PropertyMetadata.booleanProperty(CHECKSUM_PROPERTY, "Whether table is append-optimized", config.getChecksum(), false)) + ORIENTATION_PROPERTY, "Table orientation. Valid values: column, row", + AdbTableStorageOrientation.class, config.getOrientation(), false)) + .add(PropertyMetadata.booleanProperty(CHECKSUM_PROPERTY, "Whether table is append-optimized", + config.getChecksum(), false)) .add(PropertyMetadata.enumProperty( COMPRESS_TYPE_PROPERTY, "Compression type. Valid values: zlib, zstd, rle_type, none", AdbTableStorageCompressType.class, config.getCompressType(), false)) - .add(PropertyMetadata.integerProperty(COMPRESS_LEVEL_PROPERTY, "Compression level", config.getCompressLevel(), false)) - .add(PropertyMetadata.integerProperty(FILL_FACTOR_PROPERTY, "Fill factor", config.getFillFactor(), false)) + .add(PropertyMetadata.integerProperty(COMPRESS_LEVEL_PROPERTY, "Compression level", + config.getCompressLevel(), false)) + .add(PropertyMetadata.integerProperty(FILL_FACTOR_PROPERTY, "Fill factor", config.getFillFactor(), + false)) .build(); } @@ -103,7 +109,8 @@ public static Optional getBlockSize(Map tableProperties public static Optional getOrientation(Map tableProperties) { - return Optional.ofNullable(tableProperties.get(ORIENTATION_PROPERTY)).map(AdbTableStorageOrientation.class::cast); + return Optional.ofNullable(tableProperties.get(ORIENTATION_PROPERTY)) + .map(AdbTableStorageOrientation.class::cast); } public static Optional getChecksum(Map tableProperties) @@ -113,7 +120,8 @@ public static Optional getChecksum(Map tableProperties) public static Optional getCompressType(Map tableProperties) { - return Optional.ofNullable(tableProperties.get(COMPRESS_TYPE_PROPERTY)).map(AdbTableStorageCompressType.class::cast); + return Optional.ofNullable(tableProperties.get(COMPRESS_TYPE_PROPERTY)) + .map(AdbTableStorageCompressType.class::cast); } public static Optional getCompressLevel(Map tableProperties) diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/table/SplitSourceManagerImpl.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/table/SplitSourceManagerImpl.java index 9e771614a04c6..26db3c7a5da73 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/table/SplitSourceManagerImpl.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/table/SplitSourceManagerImpl.java @@ -14,7 +14,6 @@ package io.trino.plugin.adb.connector.table; import com.google.inject.Inject; -import io.trino.plugin.adb.connector.AdbJdbcSplit; import io.trino.plugin.adb.connector.AdbSessionProperties; import io.trino.plugin.adb.connector.metadata.AdbMetadataDao; import io.trino.plugin.base.mapping.IdentifierMapping; @@ -54,7 +53,13 @@ public SplitSourceManagerImpl(AdbMetadataDao metadata, IdentifierMapping identif public ConnectorSplitSource getSplits(ConnectorSession session, JdbcTableHandle tableHandle) { int parallelism = getSplitParallelism(session, tableHandle); - return new FixedSplitSource(createAdbSplits(session, tableHandle, segmentedSplits(parallelism))); + List splits = segmentedSplits(parallelism); + if (tableHandle.isNamedRelation()) { + return new FixedSplitSource(createAdbSplits(session, tableHandle, splits)); + } + else { + return new FixedSplitSource(splits); + } } private int getSplitParallelism(ConnectorSession session, JdbcTableHandle tableHandle) @@ -102,7 +107,8 @@ private List createAdbSplits(ConnectorSession session, Map tableProperties = this.metadata.getTableProperties(session, objectName, identifierMapping); List distributionInfo = AdbTableProperties.getDistributedBy(tableProperties).orElse(List.of()); return splits.stream() - .map(split -> new AdbJdbcSplit(distributionInfo, split.getAdditionalPredicate(), split.getDynamicFilter())) + .map(split -> new AdbJdbcSplit(distributionInfo, split.getAdditionalPredicate(), + split.getDynamicFilter())) .toList(); } diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/table/StatisticsManager.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/table/StatisticsManager.java index 034ee4d8ff68b..885fd1b5e626d 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/table/StatisticsManager.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/table/StatisticsManager.java @@ -22,5 +22,6 @@ public interface StatisticsManager { - TableStatistics getTableStatistics(ConnectorSession session, JdbcTableHandle handle, List columns); + TableStatistics getTableStatistics(ConnectorSession session, JdbcTableHandle handle, + List columns); } diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/table/StatisticsManagerImpl.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/table/StatisticsManagerImpl.java index e505c4628a2aa..2f1e6393fe9f1 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/table/StatisticsManagerImpl.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/table/StatisticsManagerImpl.java @@ -75,7 +75,8 @@ public TableStatistics getTableStatistics(ConnectorSession session, } } - private TableStatistics readTableStatistics(ConnectorSession session, JdbcTableHandle table, List columns) + private TableStatistics readTableStatistics(ConnectorSession session, JdbcTableHandle table, + List columns) throws SQLException { checkArgument(table.isNamedRelation(), "Relation is not a table: %s", table); @@ -101,8 +102,10 @@ private TableStatistics readTableStatistics(ConnectorSession session, JdbcTableH } RemoteTableName remoteTableName = table.getRequiredNamedRelation().getRemoteTableName(); - Map columnStatistics = statisticsDao.getColumnStatistics(remoteTableName.getSchemaName().orElse(null), remoteTableName.getTableName()).stream() - .collect(toImmutableMap(ColumnStatisticsResult::columnName, identity())); + Map columnStatistics = + statisticsDao.getColumnStatistics(remoteTableName.getSchemaName().orElse(null), + remoteTableName.getTableName()).stream() + .collect(toImmutableMap(ColumnStatisticsResult::columnName, identity())); for (JdbcColumnHandle column : columns) { ColumnStatisticsResult result = columnStatistics.get(column.getColumnName()); @@ -126,7 +129,8 @@ private TableStatistics readTableStatistics(ConnectorSession session, JdbcTableH .setDataSize(result.averageColumnLength() .flatMap(averageColumnLength -> result.nullsFraction().map(nullsFraction -> - Estimate.of(1.0 * averageColumnLength * rowCount * (1 - nullsFraction)))) + Estimate.of( + 1.0 * averageColumnLength * rowCount * (1 - nullsFraction)))) .orElseGet(Estimate::unknown)) .build(); @@ -147,7 +151,8 @@ private static Optional readRowCountTableStat(StatisticsDao statisticsDao, return Optional.empty(); } if (statisticsDao.isPartitionedTable(schemaName, remoteTableName.getTableName())) { - Optional partitionedTableRowCount = statisticsDao.getRowCountPartitionedTableFromPgClass(schemaName, remoteTableName.getTableName()); + Optional partitionedTableRowCount = + statisticsDao.getRowCountPartitionedTableFromPgClass(schemaName, remoteTableName.getTableName()); if (partitionedTableRowCount.isPresent()) { return partitionedTableRowCount; } @@ -183,7 +188,8 @@ Optional getRowCountFromPgClass(String schema, String tableName) Optional getRowCountFromPgStat(String schema, String tableName) { - return handle.createQuery("SELECT n_live_tup FROM pg_stat_all_tables WHERE schemaname = :schema AND relname = :table_name") + return handle.createQuery( + "SELECT n_live_tup FROM pg_stat_all_tables WHERE schemaname = :schema AND relname = :table_name") .bind("schema", schema) .bind("table_name", tableName) .mapTo(Long.class) @@ -221,7 +227,8 @@ Optional getRowCountPartitionedTableFromPgStats(String schema, String tabl List getColumnStatistics(String schema, String tableName) { - return handle.createQuery("SELECT attname, null_frac, n_distinct, avg_width FROM pg_stats WHERE schemaname = :schema AND tablename = :table_name") + return handle.createQuery( + "SELECT attname, null_frac, n_distinct, avg_width FROM pg_stats WHERE schemaname = :schema AND tablename = :table_name") .bind("schema", schema) .bind("table_name", tableName) .map((rs, ctx) -> new ColumnStatisticsResult( From d591a023a6df1da90cb878b3712a303623b81a0f Mon Sep 17 00:00:00 2001 From: avv Date: Thu, 28 Nov 2024 19:28:12 +0500 Subject: [PATCH 2/4] ADH-5240 - refactored page sink --- .../gpfdist/load/context/WriteContext.java | 36 ++++---- .../process/GpfdistPageProcessorProvider.java | 87 +++++++++++++++++++ .../gpfdist/load/process/GpfdistPageSink.java | 25 +----- .../load/process/GpfdistPageSinkProvider.java | 12 ++- .../load/process/PageProcessorProvider.java | 27 ++++++ .../gpfdist/server/GpfdistResource.java | 29 ++++--- 6 files changed, 159 insertions(+), 57 deletions(-) create mode 100644 plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageProcessorProvider.java create mode 100644 plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/PageProcessorProvider.java diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/context/WriteContext.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/context/WriteContext.java index 6d41c662350e6..aa6b081375e50 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/context/WriteContext.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/context/WriteContext.java @@ -18,11 +18,11 @@ import io.trino.plugin.adb.connector.encode.RowEncoder; import io.trino.plugin.adb.connector.protocol.gpfdist.Context; import io.trino.plugin.adb.connector.protocol.gpfdist.load.PageProcessor; +import io.trino.plugin.adb.connector.protocol.gpfdist.load.process.GpfdistPageProcessorProvider; import io.trino.plugin.adb.connector.protocol.gpfdist.metadata.ContextId; import io.trino.plugin.adb.connector.protocol.gpfdist.metadata.GpfdistLoadMetadata; import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; @@ -32,31 +32,33 @@ public class WriteContext implements Context { private static final Logger log = Logger.get(WriteContext.class); - private final ContextId id; - private final GpfdistLoadMetadata metadata; - private final ConcurrentLinkedQueue pageProcessors = new ConcurrentLinkedQueue<>(); private final AtomicReference adbQueryException = new AtomicReference<>(); private final AtomicLong completedBytes = new AtomicLong(); private final AtomicLong memoryUsage = new AtomicLong(); - private final AtomicBoolean isReadyForWrite = new AtomicBoolean(false); + private final AtomicReference error = new AtomicReference<>(); + private final ContextId id; + private final GpfdistLoadMetadata metadata; private final RowEncoder rowEncoder; private final DataSize writeBufferSize; - private final AtomicReference error = new AtomicReference<>(); + private final GpfdistPageProcessorProvider pageProcessorProvider; - public WriteContext(GpfdistLoadMetadata metadata, RowEncoder rowEncoder, DataSize writeBufferSize) + public WriteContext(GpfdistLoadMetadata metadata, RowEncoder rowEncoder, DataSize writeBufferSize, + GpfdistPageProcessorProvider pageProcessorProvider) { - this(new ContextId(metadata.getSourceTable()), metadata, rowEncoder, writeBufferSize); + this(new ContextId(metadata.getSourceTable()), metadata, rowEncoder, writeBufferSize, pageProcessorProvider); } public WriteContext(ContextId id, GpfdistLoadMetadata metadata, RowEncoder rowEncoder, - DataSize writeBufferSize) + DataSize writeBufferSize, + GpfdistPageProcessorProvider pageProcessorProvider) { this.id = id; this.metadata = metadata; this.rowEncoder = rowEncoder; this.writeBufferSize = writeBufferSize; + this.pageProcessorProvider = pageProcessorProvider; } @Override @@ -80,11 +82,6 @@ public AtomicReference getAdbQueryException() return adbQueryException; } - public ConcurrentLinkedQueue getPageProcessors() - { - return pageProcessors; - } - public AtomicLong getCompletedBytes() { return completedBytes; @@ -95,11 +92,6 @@ public AtomicLong getMemoryUsage() return memoryUsage; } - public AtomicBoolean getIsReadyForWrite() - { - return isReadyForWrite; - } - public DataSize getWriteBufferSize() { return writeBufferSize; @@ -110,9 +102,15 @@ public AtomicReference getError() return error; } + public GpfdistPageProcessorProvider getPageProcessorProvider() + { + return pageProcessorProvider; + } + @Override public void close() { + ConcurrentLinkedQueue pageProcessors = pageProcessorProvider.getAll(); StringBuilder sb = new StringBuilder(); pageProcessors.forEach(processor -> { try { diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageProcessorProvider.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageProcessorProvider.java new file mode 100644 index 0000000000000..0f1d2da9330b3 --- /dev/null +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageProcessorProvider.java @@ -0,0 +1,87 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.adb.connector.protocol.gpfdist.load.process; + +import io.trino.plugin.adb.connector.protocol.gpfdist.load.PageProcessor; + +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; + +import static java.lang.String.format; + +public class GpfdistPageProcessorProvider + implements PageProcessorProvider +{ + private static final long ADB_SEGMENT_WAIT_TIMEOUT = 60000L; + private final ConcurrentLinkedQueue pageProcessors = new ConcurrentLinkedQueue<>(); + private final AtomicBoolean isReadyForProcessing = new AtomicBoolean(false); + private final ReentrantLock lock = new ReentrantLock(); + private final Condition isReadyForProcessingCondition = lock.newCondition(); + + public GpfdistPageProcessorProvider() + { + } + + @Override + public void add(PageProcessor processor) + { + lock.lock(); + try { + pageProcessors.add(processor); + isReadyForProcessing.set(true); + isReadyForProcessingCondition.signalAll(); + } + finally { + lock.unlock(); + } + } + + @Override + public PageProcessor take() + { + lock.lock(); + try { + if (!isReadyForProcessing.get()) { + long startTime = System.currentTimeMillis(); + while (pageProcessors.isEmpty()) { + try { + if (System.currentTimeMillis() - startTime > ADB_SEGMENT_WAIT_TIMEOUT) { + throw new RuntimeException( + format("Timeout :%d ms waiting for segments responses is exceeded", + ADB_SEGMENT_WAIT_TIMEOUT)); + } + isReadyForProcessingCondition.await(); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + } + PageProcessor pageProcessor = pageProcessors.poll(); + pageProcessors.offer(pageProcessor); + return pageProcessor; + } + finally { + lock.unlock(); + } + } + + @Override + public ConcurrentLinkedQueue getAll() + { + return pageProcessors; + } +} diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageSink.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageSink.java index 4de726c4acdfb..9719b6ca47443 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageSink.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageSink.java @@ -35,7 +35,6 @@ public class GpfdistPageSink implements ConnectorPageSink { private static final Logger log = Logger.get(GpfdistPageSink.class); - private static final long ADB_SEGMENT_WAIT_TIMEOUT = 60000L; private final ContextManager writeContextManager; private final WriteContext writeContext; private final CompletableFuture queryLoadFuture; @@ -57,10 +56,8 @@ public GpfdistPageSink(ContextManager writeContextManager, public CompletableFuture appendPage(Page page) { pageProcessingFuture = CompletableFuture.runAsync(() -> { - waitForProcessors(); if (writeContext.getAdbQueryException().get() == null) { - PageProcessor pageProcessor = writeContext.getPageProcessors().poll(); - writeContext.getPageProcessors().offer(pageProcessor); + PageProcessor pageProcessor = writeContext.getPageProcessorProvider().take(); pageProcessor.process(page); } else { @@ -70,26 +67,6 @@ public CompletableFuture appendPage(Page page) return pageProcessingFuture; } - private void waitForProcessors() - { - try { - if (!writeContext.getIsReadyForWrite().get()) { - long startTime = System.currentTimeMillis(); - while (writeContext.getPageProcessors().isEmpty()) { - if (System.currentTimeMillis() - startTime > ADB_SEGMENT_WAIT_TIMEOUT) { - throw new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, - "Timed out after waiting for ${ADB_SEGMENT_WAIT_TIMEOUT} ms for segments"); - } - Thread.sleep(100L); - } - writeContext.getIsReadyForWrite().set(true); - } - } - catch (InterruptedException e) { - throw new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, e); - } - } - @Override public CompletableFuture> finish() { diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageSinkProvider.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageSinkProvider.java index 793bf35ef9275..5bb7d8743aea8 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageSinkProvider.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageSinkProvider.java @@ -78,8 +78,10 @@ public GpfdistPageSinkProvider(@ForBaseJdbc JdbcClient client, this.rowEncoderFactory = rowEncoderFactory; this.externalTableFormatConfigFactory = externalTableFormatConfigFactory; this.loadQueryThreadExecutor = ExecutorServiceProvider.LOAD_DATA_QUERY_EXECUTOR_SERVICE; - Map externalTableQueryFactoryMap = createExternalTableQueryFactories.stream() - .collect(Collectors.toMap(CreateExternalTableQueryFactory::getExternalTableType, Function.identity())); + Map externalTableQueryFactoryMap = + createExternalTableQueryFactories.stream() + .collect(Collectors.toMap(CreateExternalTableQueryFactory::getExternalTableType, + Function.identity())); externalTableCreateQueryFactory = externalTableQueryFactoryMap.get(EXTERNAL_TABLE_TYPE); checkArgument(externalTableCreateQueryFactory != null, "failed to get writable table query factory by externalTableType %s", @@ -107,7 +109,8 @@ public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHa ConnectorInsertTableHandle insertTableHandle, ConnectorPageSinkId pageSinkId) { - return createPageSinkInternal(transactionHandle, session, (ConnectorOutputTableHandle) insertTableHandle, pageSinkId); + return createPageSinkInternal(transactionHandle, session, (ConnectorOutputTableHandle) insertTableHandle, + pageSinkId); } private ConnectorPageSink createPageSinkInternal(ConnectorTransactionHandle transactionHandle, @@ -122,7 +125,8 @@ private ConnectorPageSink createPageSinkInternal(ConnectorTransactionHandle tran WriteContext writeContext = new WriteContext( loadMetadata, rowEncoderFactory.create(session, loadMetadata.getDataTypes()), - pluginConfig.getWriteBufferSize()); + pluginConfig.getWriteBufferSize(), + new GpfdistPageProcessorProvider()); DataTransferQueryExecutor loadDataExecutor = new GpfdistLoadDataTransferQueryExecutor(client, session, loadQueryThreadExecutor, diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/PageProcessorProvider.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/PageProcessorProvider.java new file mode 100644 index 0000000000000..cca2f093e9c0e --- /dev/null +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/PageProcessorProvider.java @@ -0,0 +1,27 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.adb.connector.protocol.gpfdist.load.process; + +import io.trino.plugin.adb.connector.protocol.gpfdist.load.PageProcessor; + +import java.util.concurrent.ConcurrentLinkedQueue; + +public interface PageProcessorProvider +{ + void add(PageProcessor processor); + + PageProcessor take(); + + ConcurrentLinkedQueue getAll(); +} diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/server/GpfdistResource.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/server/GpfdistResource.java index 72e4a2a2ec0f9..8ed1e85e984f1 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/server/GpfdistResource.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/server/GpfdistResource.java @@ -81,7 +81,8 @@ public GpfdistResource(ContextManager writeContextManager, @GET @Produces("text/plain") @Path("/read/{tableName}") - public void get(@PathParam("tableName") String tableName, @Context HttpHeaders headers, @Suspended AsyncResponse asyncResponse) + public void get(@PathParam("tableName") String tableName, @Context HttpHeaders headers, + @Suspended AsyncResponse asyncResponse) { GpfdistReadableRequest request = GpfdistReadableRequest.create(tableName, headers.getRequestHeaders()); checkArgument(request.getGpProtocol() == GPFDIST_FOR_READ_PROTOCOL_VERSION, @@ -101,7 +102,8 @@ public void get(@PathParam("tableName") String tableName, @Context HttpHeaders h } } - private void processGetRequest(AsyncResponse asyncResponse, GpfdistReadableRequest request, WriteContext writeContext) + private void processGetRequest(AsyncResponse asyncResponse, GpfdistReadableRequest request, + WriteContext writeContext) { int bufferSizeInBytes = Long.valueOf(pluginConfig.getWriteBufferSize().toBytes()).intValue(); try (PipedOutputStream outputStream = new PipedOutputStream(); @@ -110,7 +112,7 @@ private void processGetRequest(AsyncResponse asyncResponse, GpfdistReadableReque request, writeContext, new GpfdistPageSerializer(writeContext.getMetadata().getDataTypes(), writeContext.getRowEncoder())); - writeContext.getPageProcessors().add(gpfdistPageProcessor); + writeContext.getPageProcessorProvider().add(gpfdistPageProcessor); asyncResponse.resume(createOkGetResponseBuilder(request) .entity(inputStream) .build()); @@ -139,13 +141,15 @@ private Response.ResponseBuilder createOkGetResponseBuilder(GpfdistReadableReque @POST @Consumes("*/*") @Path("/write/{tableName}") - public void post(@PathParam("tableName") String tableName, InputStream data, @Context HttpHeaders headers, @Suspended AsyncResponse asyncResponse) + public void post(@PathParam("tableName") String tableName, InputStream data, @Context HttpHeaders headers, + @Suspended AsyncResponse asyncResponse) { try { GpfdistWritableRequest request = GpfdistWritableRequest.create(tableName, headers.getRequestHeaders()); log.debug("Received POST request: %s", request); checkArgument(request.getGpProtocol() == GPFDIST_FOR_WRITE_PROTOCOL_VERSION, - format("Gpfdist protocol version %s for write operation is supported", GPFDIST_FOR_WRITE_PROTOCOL_VERSION)); + format("Gpfdist protocol version %s for write operation is supported", + GPFDIST_FOR_WRITE_PROTOCOL_VERSION)); Optional readContextOptional = readContextManager.get(new ContextId(tableName)); if (readContextOptional.isEmpty()) { processNotFoundQueryRequest(tableName, asyncResponse, request); @@ -170,7 +174,8 @@ public void post(@PathParam("tableName") String tableName, InputStream data, @Co } } - private static void processNotFoundQueryRequest(String tableName, AsyncResponse asyncResponse, GpfdistWritableRequest request) + private static void processNotFoundQueryRequest(String tableName, AsyncResponse asyncResponse, + GpfdistWritableRequest request) { String errorMessage = "No active query for writeable table: " + tableName; asyncResponse.resume(Response.status(Response.Status.BAD_REQUEST.getStatusCode(), errorMessage) @@ -179,7 +184,8 @@ private static void processNotFoundQueryRequest(String tableName, AsyncResponse log.error("Failed to processed request: %s. " + errorMessage, request); } - private void processInitialRequest(AsyncResponse asyncResponse, ReadContext readContext, GpfdistWritableRequest request) + private void processInitialRequest(AsyncResponse asyncResponse, ReadContext readContext, + GpfdistWritableRequest request) { InputDataProcessor dataProcessor = inputDataProcessorFactory.create(readContext.getRowDecoder(), readContext.getRowProcessingService()); @@ -191,7 +197,8 @@ private void processInitialRequest(AsyncResponse asyncResponse, ReadContext read log.debug("Request for initial data transferring completed successfully: %s", request); } - private void processDataRequest(InputStream data, AsyncResponse asyncResponse, ReadContext readContext, GpfdistWritableRequest request) + private void processDataRequest(InputStream data, AsyncResponse asyncResponse, ReadContext readContext, + GpfdistWritableRequest request) { executorService.submit(() -> { try { @@ -208,7 +215,8 @@ private void processDataRequest(InputStream data, AsyncResponse asyncResponse, R }); } - private void processTearDownRequest(AsyncResponse asyncResponse, ReadContext readContext, GpfdistWritableRequest request) + private void processTearDownRequest(AsyncResponse asyncResponse, ReadContext readContext, + GpfdistWritableRequest request) { GpfdistSegmentRequestProcessor processor = getSegmentProcessor(readContext, request.getSegmentId()); processor.stop(); @@ -221,7 +229,8 @@ private void processTearDownRequest(AsyncResponse asyncResponse, ReadContext rea private GpfdistSegmentRequestProcessor getSegmentProcessor(ReadContext readContext, Integer segmentId) { return Optional.ofNullable(readContext.getSegmentDataProcessors().get(segmentId)) - .orElseThrow(() -> new IllegalStateException("Failed to get segment request processor by segmentId: " + segmentId)); + .orElseThrow(() -> new IllegalStateException( + "Failed to get segment request processor by segmentId: " + segmentId)); } private void failWriteResponse(AsyncResponse asyncResponse, Exception e) From 4281e92e3d376e1b9e409ec85d400ccfbabf9838 Mon Sep 17 00:00:00 2001 From: avv Date: Fri, 29 Nov 2024 10:18:05 +0500 Subject: [PATCH 3/4] ADH-5240 - refactored page process provider --- .../{process => }/PageProcessorProvider.java | 10 +++++----- .../gpfdist/load/context/WriteContext.java | 6 +++--- .../process/GpfdistPageProcessorProvider.java | 20 ++++++++++++++++--- 3 files changed, 25 insertions(+), 11 deletions(-) rename plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/{process => }/PageProcessorProvider.java (73%) diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/PageProcessorProvider.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/PageProcessorProvider.java similarity index 73% rename from plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/PageProcessorProvider.java rename to plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/PageProcessorProvider.java index cca2f093e9c0e..b0b7f93559a04 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/PageProcessorProvider.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/PageProcessorProvider.java @@ -11,11 +11,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.trino.plugin.adb.connector.protocol.gpfdist.load.process; +package io.trino.plugin.adb.connector.protocol.gpfdist.load; -import io.trino.plugin.adb.connector.protocol.gpfdist.load.PageProcessor; - -import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.Queue; public interface PageProcessorProvider { @@ -23,5 +21,7 @@ public interface PageProcessorProvider PageProcessor take(); - ConcurrentLinkedQueue getAll(); + Queue getAll(); + + void clear(); } diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/context/WriteContext.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/context/WriteContext.java index aa6b081375e50..cb8bb8ed4d86e 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/context/WriteContext.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/context/WriteContext.java @@ -22,7 +22,7 @@ import io.trino.plugin.adb.connector.protocol.gpfdist.metadata.ContextId; import io.trino.plugin.adb.connector.protocol.gpfdist.metadata.GpfdistLoadMetadata; -import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.Queue; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; @@ -110,7 +110,7 @@ public GpfdistPageProcessorProvider getPageProcessorProvider() @Override public void close() { - ConcurrentLinkedQueue pageProcessors = pageProcessorProvider.getAll(); + Queue pageProcessors = pageProcessorProvider.getAll(); StringBuilder sb = new StringBuilder(); pageProcessors.forEach(processor -> { try { @@ -120,7 +120,7 @@ public void close() sb.append(format("Failed to stop page processor %s. Error: %s;", processor, e.getMessage())); } }); - pageProcessors.clear(); + pageProcessorProvider.clear(); if (!sb.isEmpty()) { throw new RuntimeException(sb.toString()); } diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageProcessorProvider.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageProcessorProvider.java index 0f1d2da9330b3..f7edae4453cca 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageProcessorProvider.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageProcessorProvider.java @@ -14,8 +14,10 @@ package io.trino.plugin.adb.connector.protocol.gpfdist.load.process; import io.trino.plugin.adb.connector.protocol.gpfdist.load.PageProcessor; +import io.trino.plugin.adb.connector.protocol.gpfdist.load.PageProcessorProvider; -import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.LinkedList; +import java.util.Queue; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.ReentrantLock; @@ -26,7 +28,7 @@ public class GpfdistPageProcessorProvider implements PageProcessorProvider { private static final long ADB_SEGMENT_WAIT_TIMEOUT = 60000L; - private final ConcurrentLinkedQueue pageProcessors = new ConcurrentLinkedQueue<>(); + private final Queue pageProcessors = new LinkedList<>(); private final AtomicBoolean isReadyForProcessing = new AtomicBoolean(false); private final ReentrantLock lock = new ReentrantLock(); private final Condition isReadyForProcessingCondition = lock.newCondition(); @@ -80,8 +82,20 @@ public PageProcessor take() } @Override - public ConcurrentLinkedQueue getAll() + public Queue getAll() { return pageProcessors; } + + @Override + public void clear() + { + lock.lock(); + try { + pageProcessors.clear(); + } + finally { + lock.unlock(); + } + } } From df757582bb140a1837ebf45fead0810c4e8b4b4b Mon Sep 17 00:00:00 2001 From: avv Date: Fri, 29 Nov 2024 10:18:57 +0500 Subject: [PATCH 4/4] ADH-5240 - refactored page process provider --- .../gpfdist/load/process/GpfdistPageProcessorProvider.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageProcessorProvider.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageProcessorProvider.java index f7edae4453cca..d9081b7104607 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageProcessorProvider.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/load/process/GpfdistPageProcessorProvider.java @@ -43,7 +43,6 @@ public void add(PageProcessor processor) lock.lock(); try { pageProcessors.add(processor); - isReadyForProcessing.set(true); isReadyForProcessingCondition.signalAll(); } finally { @@ -71,6 +70,7 @@ public PageProcessor take() throw new RuntimeException(e); } } + isReadyForProcessing.set(true); } PageProcessor pageProcessor = pageProcessors.poll(); pageProcessors.offer(pageProcessor);