diff --git a/src/main/java/io/confluent/connect/hdfs/orc/OrcRecordWriterProvider.java b/src/main/java/io/confluent/connect/hdfs/orc/OrcRecordWriterProvider.java
index aba47b9e5..3499986a5 100644
--- a/src/main/java/io/confluent/connect/hdfs/orc/OrcRecordWriterProvider.java
+++ b/src/main/java/io/confluent/connect/hdfs/orc/OrcRecordWriterProvider.java
@@ -16,9 +16,9 @@
 package io.confluent.connect.hdfs.orc;

 import io.confluent.connect.hdfs.HdfsSinkConnectorConfig;
+import io.confluent.connect.hdfs.schema.HiveSchemaConverterWithLogicalTypes;
 import io.confluent.connect.storage.format.RecordWriter;
 import io.confluent.connect.storage.format.RecordWriterProvider;
-import io.confluent.connect.storage.hive.HiveSchemaConverter;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.hive.ql.io.orc.OrcFile;
 import org.apache.hadoop.hive.ql.io.orc.OrcStruct;
@@ -70,7 +70,7 @@ public void preFooterWrite(OrcFile.WriterContext writerContext) {
             }
           };

-          typeInfo = HiveSchemaConverter.convert(schema);
+          typeInfo = HiveSchemaConverterWithLogicalTypes.convert(schema);
           ObjectInspector objectInspector = OrcStruct.createObjectInspector(typeInfo);

           log.info("Opening ORC record writer for: {}", filename);
@@ -90,7 +90,10 @@ public void preFooterWrite(OrcFile.WriterContext writerContext) {
           );

           Struct struct = (Struct) record.value();
-          OrcStruct row = OrcUtil.createOrcStruct(typeInfo, OrcUtil.convertStruct(struct));
+          OrcStruct row = OrcUtil.createOrcStruct(
+              typeInfo,
+              OrcUtil.convertStruct(typeInfo, struct)
+          );
           writer.addRow(row);

         } else {
diff --git a/src/main/java/io/confluent/connect/hdfs/orc/OrcUtil.java b/src/main/java/io/confluent/connect/hdfs/orc/OrcUtil.java
index 25d783cb9..70443e5d4 100644
--- a/src/main/java/io/confluent/connect/hdfs/orc/OrcUtil.java
+++ b/src/main/java/io/confluent/connect/hdfs/orc/OrcUtil.java
@@ -28,17 +28,21 @@
 import static org.apache.kafka.connect.data.Schema.Type.STRING;
 import static org.apache.kafka.connect.data.Schema.Type.STRUCT;

+import java.math.BigDecimal;
 import java.util.HashMap;
 import java.util.Map;
-import java.util.function.BiFunction;
+
+import org.apache.hadoop.hive.common.type.HiveDecimal;
 import org.apache.hadoop.hive.ql.io.orc.OrcStruct;
 import org.apache.hadoop.hive.serde2.io.ByteWritable;
 import org.apache.hadoop.hive.serde2.io.DateWritable;
 import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
 import org.apache.hadoop.hive.serde2.io.ShortWritable;
 import org.apache.hadoop.hive.serde2.io.TimestampWritable;
 import org.apache.hadoop.hive.serde2.objectinspector.SettableStructObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.StructField;
+import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
 import org.apache.hadoop.io.ArrayPrimitiveWritable;
 import org.apache.hadoop.io.BooleanWritable;
@@ -50,6 +54,7 @@
 import org.apache.hadoop.io.ObjectWritable;
 import org.apache.hadoop.io.Text;
 import org.apache.kafka.connect.data.Date;
+import org.apache.kafka.connect.data.Decimal;
 import org.apache.kafka.connect.data.Field;
 import org.apache.kafka.connect.data.Schema;
 import org.apache.kafka.connect.data.Schema.Type;
@@ -59,10 +64,12 @@

 import java.util.LinkedList;
 import java.util.List;
+import java.util.function.BiFunction;

 public final class OrcUtil {

-  private static Map<Type, BiFunction<Struct, Field, Object>> CONVERSION_MAP = new HashMap<>();
+  private static final Map<Type, BiFunction<Struct, Field, Object>> CONVERSION_MAP =
+      new HashMap<>();

   static {
     CONVERSION_MAP.put(ARRAY, OrcUtil::convertArray);
@@ -76,7 +83,6 @@ public final class OrcUtil {
     CONVERSION_MAP.put(INT64, OrcUtil::convertInt64);
     CONVERSION_MAP.put(MAP, OrcUtil::convertMap);
     CONVERSION_MAP.put(STRING, OrcUtil::convertString);
-    CONVERSION_MAP.put(STRUCT, OrcUtil::convertStruct);
   }

   /**
@@ -87,8 +93,8 @@ public final class OrcUtil {
    * @return the struct object
    */
   @SuppressWarnings("unchecked")
-  public static OrcStruct createOrcStruct(TypeInfo typeInfo, Object... objs) {
-    SettableStructObjectInspector oi = (SettableStructObjectInspector)
+  public static OrcStruct createOrcStruct(TypeInfo typeInfo, Object[] objs) {
+    SettableStructObjectInspector oi = (SettableStructObjectInspector)
             OrcStruct.createObjectInspector(typeInfo);

     List<StructField> fields = (List<StructField>) oi.getAllStructFieldRefs();
@@ -107,22 +113,31 @@ public static OrcStruct createOrcStruct(TypeInfo typeInfo, Object... objs) {
    * @param struct the struct to convert
    * @return the struct as a writable array
    */
-  public static Object[] convertStruct(Struct struct) {
+  public static Object[] convertStruct(TypeInfo typeInfo, Struct struct) {
     List<Object> data = new LinkedList<>();
     for (Field field : struct.schema().fields()) {
       if (struct.get(field) == null) {
         data.add(null);
       } else {
         Schema.Type schemaType = field.schema().type();
-        data.add(CONVERSION_MAP.get(schemaType).apply(struct, field));
+        if (STRUCT.equals(schemaType)) {
+          data.add(convertStruct(typeInfo, struct, field));
+        } else {
+          data.add(CONVERSION_MAP.get(schemaType).apply(struct, field));
+        }
       }
     }

     return data.toArray();
   }

-  private static Object convertStruct(Struct struct, Field field) {
-    return convertStruct(struct.getStruct(field.name()));
+  private static Object convertStruct(TypeInfo typeInfo, Struct struct, Field field) {
+    TypeInfo fieldTypeInfo = ((StructTypeInfo) typeInfo).getStructFieldTypeInfo(field.name());
+
+    return createOrcStruct(
+        fieldTypeInfo,
+        convertStruct(fieldTypeInfo, struct.getStruct(field.name()))
+    );
   }

   private static Object convertArray(Struct struct, Field field) {
@@ -134,6 +149,12 @@ private static Object convertBoolean(Struct struct, Field field) {
   }

   private static Object convertBytes(Struct struct, Field field) {
+
+    if (Decimal.LOGICAL_NAME.equals(field.schema().name())) {
+      BigDecimal bigDecimal = (BigDecimal) struct.get(field.name());
+      return new HiveDecimalWritable(HiveDecimal.create(bigDecimal));
+    }
+
     return new BytesWritable(struct.getBytes(field.name()));
   }

@@ -162,7 +183,7 @@ private static Object convertInt32(Struct struct, Field field) {

     if (Time.LOGICAL_NAME.equals(field.schema().name())) {
       java.util.Date date = (java.util.Date) struct.get(field);
-      return new TimestampWritable(new java.sql.Timestamp(date.getTime()));
+      return new IntWritable((int) date.getTime());
     }

     return new IntWritable(struct.getInt32(field.name()));
diff --git a/src/main/java/io/confluent/connect/hdfs/schema/HiveSchemaConverterWithLogicalTypes.java b/src/main/java/io/confluent/connect/hdfs/schema/HiveSchemaConverterWithLogicalTypes.java
new file mode 100644
index 000000000..c7eacfe7d
--- /dev/null
+++ b/src/main/java/io/confluent/connect/hdfs/schema/HiveSchemaConverterWithLogicalTypes.java
@@ -0,0 +1,88 @@
+/*
+ * Copyright 2020 Confluent Inc.
+ *
+ * Licensed under the Confluent Community License (the "License"); you may not use
+ * this file except in compliance with the License. You may obtain a copy of the
+ * License at
+ *
+ * http://www.confluent.io/confluent-community-license
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package io.confluent.connect.hdfs.schema;
+
+import io.confluent.connect.storage.hive.HiveSchemaConverter;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
+import org.apache.kafka.connect.data.Date;
+import org.apache.kafka.connect.data.Field;
+import org.apache.kafka.connect.data.Schema;
+import org.apache.kafka.connect.data.Timestamp;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class HiveSchemaConverterWithLogicalTypes {
+
+  public static TypeInfo convert(Schema schema) {
+    // TODO: throw an error on recursive types
+    switch (schema.type()) {
+      case STRUCT:
+        return convertStruct(schema);
+      case ARRAY:
+        return convertArray(schema);
+      case MAP:
+        return convertMap(schema);
+      default:
+        return convertPrimitive(schema);
+    }
+  }
+
+  public static TypeInfo convertStruct(Schema schema) {
+    final List<Field> fields = schema.fields();
+    final List<String> names = new ArrayList<>(fields.size());
+    final List<TypeInfo> types = new ArrayList<>(fields.size());
+    for (Field field : fields) {
+      names.add(field.name());
+      types.add(convert(field.schema()));
+    }
+    return TypeInfoFactory.getStructTypeInfo(names, types);
+  }
+
+  public static TypeInfo convertArray(Schema schema) {
+    return TypeInfoFactory.getListTypeInfo(convert(schema.valueSchema()));
+  }
+
+  public static TypeInfo convertMap(Schema schema) {
+    return TypeInfoFactory.getMapTypeInfo(
+        convert(schema.keySchema()),
+        convert(schema.valueSchema())
+    );
+  }
+
+  public static TypeInfo convertPrimitive(Schema schema) {
+    if (schema.name() != null) {
+      switch (schema.name()) {
+        case Date.LOGICAL_NAME:
+          return TypeInfoFactory.dateTypeInfo;
+        case Timestamp.LOGICAL_NAME:
+          return TypeInfoFactory.timestampTypeInfo;
+        // NOTE: We currently leave TIME values as INT32 (the default).
+        //       Converting to a STRING would be ok too.
+        //       Sadly, writing as INTERVAL is unsupported in the kafka-connect library.
+        //       See: org.apache.hadoop.hive.ql.io.orc.WriterImpl - INTERVAL is missing
+        //case Time.LOGICAL_NAME:
+        //  return TypeInfoFactory.intervalDayTimeTypeInfo;
+        default:
+          break;
+      }
+    }
+
+    // HiveSchemaConverter converts primitives just fine, just not all logical-types.
+    return HiveSchemaConverter.convertPrimitiveMaybeLogical(schema);
+  }
+}
\ No newline at end of file
diff --git a/src/test/java/io/confluent/connect/hdfs/orc/DataWriterOrcTest.java b/src/test/java/io/confluent/connect/hdfs/orc/DataWriterOrcTest.java
index 778a0b9d8..a65115949 100644
--- a/src/test/java/io/confluent/connect/hdfs/orc/DataWriterOrcTest.java
+++ b/src/test/java/io/confluent/connect/hdfs/orc/DataWriterOrcTest.java
@@ -19,7 +19,7 @@
 import io.confluent.connect.hdfs.DataWriter;
 import io.confluent.connect.hdfs.HdfsSinkConnectorConfig;
 import io.confluent.connect.hdfs.TestWithMiniDFSCluster;
-import io.confluent.connect.storage.hive.HiveSchemaConverter;
+import io.confluent.connect.hdfs.schema.HiveSchemaConverterWithLogicalTypes;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
 import org.apache.kafka.connect.data.Field;
 import org.apache.kafka.connect.data.Schema;
@@ -80,7 +80,7 @@ protected void verifyContents(List<SinkRecord> expectedRecords, int startIndex,
         expectedRecords.get(startIndex++).value(),
         expectedSchema);

-    TypeInfo typeInfo = HiveSchemaConverter.convert(expectedSchema);
+    TypeInfo typeInfo = HiveSchemaConverterWithLogicalTypes.convert(expectedSchema);
     ArrayList<Object> objs = new ArrayList<>();

     for (Field field : expectedSchema.fields()) {
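Illustrative sketch (not part of the patch): how the pieces introduced above fit together at runtime. A Connect schema with Decimal and Timestamp logical types is converted to a logical-type-aware TypeInfo and then to an OrcStruct, mirroring the call sequence the updated OrcRecordWriterProvider uses per record. The class name and local variables below (LogicalTypeConversionSketch, schema, value, row) are invented for the example; everything else comes from the diff or the Kafka Connect data API.

import io.confluent.connect.hdfs.orc.OrcUtil;
import io.confluent.connect.hdfs.schema.HiveSchemaConverterWithLogicalTypes;
import org.apache.hadoop.hive.ql.io.orc.OrcStruct;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.kafka.connect.data.Decimal;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.SchemaBuilder;
import org.apache.kafka.connect.data.Struct;
import org.apache.kafka.connect.data.Timestamp;

import java.math.BigDecimal;

public class LogicalTypeConversionSketch {

  public static void main(String[] args) {
    // A record schema using two logical types touched by this PR:
    // Decimal (BYTES + logical name) and Timestamp (INT64 + logical name).
    Schema schema = SchemaBuilder.struct().name("measurement")
        .field("amount", Decimal.schema(2))
        .field("observedAt", Timestamp.SCHEMA)
        .build();

    Struct value = new Struct(schema)
        .put("amount", new BigDecimal("12.34"))
        .put("observedAt", new java.util.Date());

    // New converter: maps Connect logical types to Hive TypeInfos (date/timestamp
    // handled explicitly, decimal via the delegated primitive conversion).
    TypeInfo typeInfo = HiveSchemaConverterWithLogicalTypes.convert(schema);

    // Same two-step call the updated OrcRecordWriterProvider performs per record:
    // typeInfo-aware struct conversion, then ORC struct creation.
    OrcStruct row = OrcUtil.createOrcStruct(typeInfo, OrcUtil.convertStruct(typeInfo, value));

    System.out.println(typeInfo);
    System.out.println(row);
  }
}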