From 869aa4eca5228b6bf9fda29e81b0310330e1fc0d Mon Sep 17 00:00:00 2001 From: Wojciech Trefon Date: Thu, 3 Oct 2024 15:51:43 +0200 Subject: [PATCH] Refactor schema resolver --- .../schemaevolution/ColumnInfos.java | 5 + .../schemaevolution/TableSchemaResolver.java | 121 +++++++++++++----- 2 files changed, 96 insertions(+), 30 deletions(-) diff --git a/src/main/java/com/snowflake/kafka/connector/internal/streaming/schemaevolution/ColumnInfos.java b/src/main/java/com/snowflake/kafka/connector/internal/streaming/schemaevolution/ColumnInfos.java index 1ee8b82bb..430fda2ec 100644 --- a/src/main/java/com/snowflake/kafka/connector/internal/streaming/schemaevolution/ColumnInfos.java +++ b/src/main/java/com/snowflake/kafka/connector/internal/streaming/schemaevolution/ColumnInfos.java @@ -12,6 +12,11 @@ public ColumnInfos(String columnType, String comments) { this.comments = comments; } + public ColumnInfos(String columnType) { + this.columnType = columnType; + this.comments = null; + } + public String getColumnType() { return columnType; } diff --git a/src/main/java/com/snowflake/kafka/connector/internal/streaming/schemaevolution/TableSchemaResolver.java b/src/main/java/com/snowflake/kafka/connector/internal/streaming/schemaevolution/TableSchemaResolver.java index 2f8bb10a5..e8f3b1b9a 100644 --- a/src/main/java/com/snowflake/kafka/connector/internal/streaming/schemaevolution/TableSchemaResolver.java +++ b/src/main/java/com/snowflake/kafka/connector/internal/streaming/schemaevolution/TableSchemaResolver.java @@ -1,14 +1,16 @@ package com.snowflake.kafka.connector.internal.streaming.schemaevolution; +import com.google.common.collect.Maps; +import com.google.common.collect.Streams; import com.snowflake.kafka.connector.Utils; import com.snowflake.kafka.connector.internal.SnowflakeErrors; import com.snowflake.kafka.connector.records.RecordService; import java.util.HashMap; import java.util.HashSet; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; import net.snowflake.client.jdbc.internal.fasterxml.jackson.databind.JsonNode; import org.apache.kafka.connect.data.Field; import org.apache.kafka.connect.data.Schema; @@ -30,40 +32,71 @@ protected TableSchemaResolver(ColumnTypeMapper columnTypeMapper) { * With the list of columns, collect their data types from either the schema or the data itself * * @param record the sink record that contains the schema and actual data - * @param columnNames the names of the extra columns + * @param columnsToInclude the names of the columns to include in the schema * @return a Map object where the key is column name and value is ColumnInfos */ - public TableSchema resolveTableSchemaFromRecord(SinkRecord record, List columnNames) { - if (columnNames == null) { + public TableSchema resolveTableSchemaFromRecord( + SinkRecord record, List columnsToInclude) { + if (columnsToInclude == null) { return new TableSchema(new HashMap<>()); } - Map columnToType = new HashMap<>(); - Map schemaMap = getSchemaMapFromRecord(record); + JsonNode recordNode = RecordService.convertToJson(record.valueSchema(), record.value(), true); - Set columnNamesSet = new HashSet<>(columnNames); - - Iterator> fields = recordNode.fields(); - while (fields.hasNext()) { - Map.Entry field = fields.next(); - String colName = Utils.quoteNameIfNeeded(field.getKey()); - if (columnNamesSet.contains(colName)) { - ColumnInfos columnInfos; - if (schemaMap.isEmpty()) { - // No schema associated with the record, we will try to infer it based on the data - columnInfos = new ColumnInfos(inferDataTypeFromJsonObject(field.getValue()), null); - } else { - // Get from the schema - columnInfos = schemaMap.get(field.getKey()); - if (columnInfos == null) { - // only when the type of the value is unrecognizable for JAVA - throw SnowflakeErrors.ERROR_5022.getException( - "column: " + field.getKey() + " schemaMap: " + schemaMap); - } - } - columnToType.put(colName, columnInfos); - } + Set columnNamesSet = new HashSet<>(columnsToInclude); + + if (hasSchema(record)) { + return getTableSchemaFromRecordSchema(recordNode, columnNamesSet, record); + } else { + return getTableSchemaFromJson(recordNode, columnNamesSet); } - return new TableSchema(columnToType); + } + + private boolean hasSchema(SinkRecord record) { + return record.valueSchema() != null + && record.valueSchema().fields() != null + && !record.valueSchema().fields().isEmpty(); + } + + private TableSchema getTableSchemaFromRecordSchema( + JsonNode recordNode, Set columnNamesSet, SinkRecord record) { + Map schemaMap = getFullSchemaMapFromRecord(record); + Map columnsInferredFromSchema = + Streams.stream(recordNode.fields()) + .map(ColumnValuePair::of) + .filter(pair -> columnNamesSet.contains(pair.getQuotedColumnName())) + .peek( + field -> { + if (!schemaMap.containsKey(field.getColumnName())) { + // only when the type of the value is unrecognizable for JAVA + throw SnowflakeErrors.ERROR_5022.getException( + "column: " + field.getColumnName() + " schemaMap: " + schemaMap); + } + }) + .map( + field -> + Maps.immutableEntry( + Utils.quoteNameIfNeeded(field.getQuotedColumnName()), + schemaMap.get(field.getColumnName()))) + .collect( + Collectors.toMap( + Map.Entry::getKey, Map.Entry::getValue, (oldValue, newValue) -> newValue)); + return new TableSchema(columnsInferredFromSchema); + } + + private TableSchema getTableSchemaFromJson(JsonNode recordNode, Set columnNamesSet) { + Map columnsInferredFromJson = + Streams.stream(recordNode.fields()) + .map(ColumnValuePair::of) + .filter(pair -> columnNamesSet.contains(pair.getQuotedColumnName())) + .map( + pair -> + Maps.immutableEntry( + pair.getQuotedColumnName(), + new ColumnInfos(inferDataTypeFromJsonObject(pair.getJsonNode())))) + .collect( + Collectors.toMap( + Map.Entry::getKey, Map.Entry::getValue, (oldValue, newValue) -> newValue)); + return new TableSchema(columnsInferredFromJson); } /** @@ -72,7 +105,7 @@ public TableSchema resolveTableSchemaFromRecord(SinkRecord record, List * @param record the sink record that contains the schema and actual data * @return a Map object where the key is column name and value is ColumnInfos */ - private Map getSchemaMapFromRecord(SinkRecord record) { + private Map getFullSchemaMapFromRecord(SinkRecord record) { Map schemaMap = new HashMap<>(); Schema schema = record.valueSchema(); if (schema != null && schema.fields() != null) { @@ -104,4 +137,32 @@ private String inferDataTypeFromJsonObject(JsonNode value) { // Passing null to schemaName when there is no schema information return columnTypeMapper.mapToColumnType(schemaType); } + + private static class ColumnValuePair { + private final String columnName; + private final String quotedColumnName; + private final JsonNode jsonNode; + + public static ColumnValuePair of(Map.Entry field) { + return new ColumnValuePair(field.getKey(), field.getValue()); + } + + private ColumnValuePair(String columnName, JsonNode jsonNode) { + this.columnName = columnName; + this.quotedColumnName = Utils.quoteNameIfNeeded(columnName); + this.jsonNode = jsonNode; + } + + public String getColumnName() { + return columnName; + } + + public String getQuotedColumnName() { + return quotedColumnName; + } + + public JsonNode getJsonNode() { + return jsonNode; + } + } }