diff --git a/src/main/java/com/snowflake/kafka/connector/internal/streaming/schemaevolution/ColumnInfos.java b/src/main/java/com/snowflake/kafka/connector/internal/streaming/schemaevolution/ColumnInfos.java
index 1ee8b82bb..430fda2ec 100644
--- a/src/main/java/com/snowflake/kafka/connector/internal/streaming/schemaevolution/ColumnInfos.java
+++ b/src/main/java/com/snowflake/kafka/connector/internal/streaming/schemaevolution/ColumnInfos.java
@@ -12,6 +12,16 @@ public ColumnInfos(String columnType, String comments) {
     this.comments = comments;
   }
 
+  /**
+   * Creates column metadata that carries only a type; {@code comments} is left {@code null}.
+   *
+   * @param columnType the column type to store
+   */
+  public ColumnInfos(String columnType) {
+    this.columnType = columnType;
+    this.comments = null;
+  }
+
   public String getColumnType() {
     return columnType;
   }
diff --git a/src/main/java/com/snowflake/kafka/connector/internal/streaming/schemaevolution/TableSchemaResolver.java b/src/main/java/com/snowflake/kafka/connector/internal/streaming/schemaevolution/TableSchemaResolver.java
index 2f8bb10a5..e5ac01f97 100644
--- a/src/main/java/com/snowflake/kafka/connector/internal/streaming/schemaevolution/TableSchemaResolver.java
+++ b/src/main/java/com/snowflake/kafka/connector/internal/streaming/schemaevolution/TableSchemaResolver.java
@@ -1,14 +1,17 @@
 package com.snowflake.kafka.connector.internal.streaming.schemaevolution;
 
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Streams;
 import com.snowflake.kafka.connector.Utils;
 import com.snowflake.kafka.connector.internal.SnowflakeErrors;
 import com.snowflake.kafka.connector.records.RecordService;
 import java.util.HashMap;
 import java.util.HashSet;
-import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.stream.Collectors;
 import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.JsonNode;
 import org.apache.kafka.connect.data.Field;
 import org.apache.kafka.connect.data.Schema;
@@ -30,40 +33,104 @@ protected TableSchemaResolver(ColumnTypeMapper columnTypeMapper) {
    * With the list of columns, collect their data types from either the schema or the data itself
    *
    * @param record the sink record that contains the schema and actual data
-   * @param columnNames the names of the extra columns
-   * @return a Map object where the key is column name and value is ColumnInfos
+   * @param columnsToInclude the names of the columns to include in the schema
+   * @return a TableSchema that maps each included column name to its ColumnInfos
    */
-  public TableSchema resolveTableSchemaFromRecord(SinkRecord record, List<String> columnNames) {
-    if (columnNames == null) {
-      return new TableSchema(new HashMap<>());
+  public TableSchema resolveTableSchemaFromRecord(
+      SinkRecord record, List<String> columnsToInclude) {
+    if (columnsToInclude == null || columnsToInclude.isEmpty()) {
+      return new TableSchema(ImmutableMap.of());
     }
-    Map<String, ColumnInfos> columnToType = new HashMap<>();
-    Map<String, ColumnInfos> schemaMap = getSchemaMapFromRecord(record);
+
+    Set<String> columnNamesSet = new HashSet<>(columnsToInclude);
+
+    if (hasSchema(record)) {
+      return getTableSchemaFromRecordSchema(record, columnNamesSet);
+    } else {
+      return getTableSchemaFromJson(record, columnNamesSet);
+    }
+  }
+
+  /**
+   * Tells whether the record's value schema is present and declares at least one field.
+   *
+   * @param record the sink record to inspect
+   * @return {@code true} when the value schema is non-null and has a non-empty field list
+   */
+  private boolean hasSchema(SinkRecord record) {
+    return record.valueSchema() != null
+        && record.valueSchema().fields() != null
+        && !record.valueSchema().fields().isEmpty();
+  }
+
+  /**
+   * Builds the table schema from the record's value schema.
+   *
+   * <p>Fails with {@code SnowflakeErrors.ERROR_5022} if a requested field of the record has no
+   * entry in the schema map.
+   *
+   * @param record the sink record that contains the schema and actual data
+   * @param columnNamesSet the columns to include, matched against the quoted field names
+   * @return the table schema for the requested columns found in the record
+   */
+  private TableSchema getTableSchemaFromRecordSchema(
+      SinkRecord record, Set<String> columnNamesSet) {
     JsonNode recordNode = RecordService.convertToJson(record.valueSchema(), record.value(), true);
-    Set<String> columnNamesSet = new HashSet<>(columnNames);
-
-    Iterator<Map.Entry<String, JsonNode>> fields = recordNode.fields();
-    while (fields.hasNext()) {
-      Map.Entry<String, JsonNode> field = fields.next();
-      String colName = Utils.quoteNameIfNeeded(field.getKey());
-      if (columnNamesSet.contains(colName)) {
-        ColumnInfos columnInfos;
-        if (schemaMap.isEmpty()) {
-          // No schema associated with the record, we will try to infer it based on the data
-          columnInfos = new ColumnInfos(inferDataTypeFromJsonObject(field.getValue()), null);
-        } else {
-          // Get from the schema
-          columnInfos = schemaMap.get(field.getKey());
-          if (columnInfos == null) {
-            // only when the type of the value is unrecognizable for JAVA
-            throw SnowflakeErrors.ERROR_5022.getException(
-                "column: " + field.getKey() + " schemaMap: " + schemaMap);
-          }
-        }
-        columnToType.put(colName, columnInfos);
-      }
+    Map<String, ColumnInfos> schemaMap = getFullSchemaMapFromRecord(record);
+    Map<Boolean, List<ColumnValuePair>> columnsWithValue =
+        Streams.stream(recordNode.fields())
+            .map(ColumnValuePair::from)
+            .filter(pair -> columnNamesSet.contains(pair.getQuotedColumnName()))
+            .collect(
+                Collectors.partitioningBy((pair -> schemaMap.containsKey(pair.getColumnName()))));
+
+    List<ColumnValuePair> notFoundFieldsInSchema = columnsWithValue.get(false);
+    List<ColumnValuePair> foundFieldsInSchema = columnsWithValue.get(true);
+
+    if (!notFoundFieldsInSchema.isEmpty()) {
+      throw SnowflakeErrors.ERROR_5022.getException(
+          "Columns not found in schema: "
+              + notFoundFieldsInSchema.stream()
+                  .map(ColumnValuePair::getColumnName)
+                  .collect(Collectors.toList())
+              + ", schemaMap: "
+              + schemaMap);
     }
-    return new TableSchema(columnToType);
+    Map<String, ColumnInfos> columnsInferredFromSchema =
+        foundFieldsInSchema.stream()
+            .map(
+                pair ->
+                    Maps.immutableEntry(
+                        pair.getQuotedColumnName(),
+                        schemaMap.get(pair.getColumnName())))
+            .collect(
+                Collectors.toMap(
+                    Map.Entry::getKey, Map.Entry::getValue, (oldValue, newValue) -> newValue));
+    return new TableSchema(columnsInferredFromSchema);
+  }
+
+  /**
+   * Builds the table schema by inferring each column type from the record's JSON value.
+   *
+   * @param record the sink record that contains the actual data
+   * @param columnNamesSet the columns to include, matched against the quoted field names
+   * @return the table schema for the requested columns found in the record
+   */
+  private TableSchema getTableSchemaFromJson(SinkRecord record, Set<String> columnNamesSet) {
+    JsonNode recordNode = RecordService.convertToJson(record.valueSchema(), record.value(), true);
+    Map<String, ColumnInfos> columnsInferredFromJson =
+        Streams.stream(recordNode.fields())
+            .map(ColumnValuePair::from)
+            .filter(pair -> columnNamesSet.contains(pair.getQuotedColumnName()))
+            .map(
+                pair ->
+                    Maps.immutableEntry(
+                        pair.getQuotedColumnName(),
+                        new ColumnInfos(inferDataTypeFromJsonObject(pair.getJsonNode()))))
+            .collect(
+                Collectors.toMap(
+                    Map.Entry::getKey, Map.Entry::getValue, (oldValue, newValue) -> newValue));
+    return new TableSchema(columnsInferredFromJson);
   }
 
   /**
@@ -72,7 +139,7 @@ public TableSchema resolveTableSchemaFromRecord(SinkRecord record, List<String>
    * @param record the sink record that contains the schema and actual data
    * @return a Map object where the key is column name and value is ColumnInfos
    */
-  private Map<String, ColumnInfos> getSchemaMapFromRecord(SinkRecord record) {
+  private Map<String, ColumnInfos> getFullSchemaMapFromRecord(SinkRecord record) {
     Map<String, ColumnInfos> schemaMap = new HashMap<>();
     Schema schema = record.valueSchema();
     if (schema != null && schema.fields() != null) {
@@ -104,4 +171,33 @@ private String inferDataTypeFromJsonObject(JsonNode value) {
     // Passing null to schemaName when there is no schema information
     return columnTypeMapper.mapToColumnType(schemaType);
   }
+
+  /** A field name, raw and quoted via {@code Utils.quoteNameIfNeeded}, with its JSON value. */
+  private static class ColumnValuePair {
+    private final String columnName;
+    private final String quotedColumnName;
+    private final JsonNode jsonNode;
+
+    public static ColumnValuePair from(Map.Entry<String, JsonNode> field) {
+      return new ColumnValuePair(field.getKey(), field.getValue());
+    }
+
+    private ColumnValuePair(String columnName, JsonNode jsonNode) {
+      this.columnName = columnName;
+      this.quotedColumnName = Utils.quoteNameIfNeeded(columnName);
+      this.jsonNode = jsonNode;
+    }
+
+    public String getColumnName() {
+      return columnName;
+    }
+
+    public String getQuotedColumnName() {
+      return quotedColumnName;
+    }
+
+    public JsonNode getJsonNode() {
+      return jsonNode;
+    }
+  }
 }