diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
index 85e9a6f4..8dbc4bfb 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
@@ -17,14 +17,11 @@
 
 package org.apache.doris.spark.serialization;
 
-import org.apache.doris.sdk.thrift.TScanBatchResult;
-import org.apache.doris.spark.exception.DorisException;
-import org.apache.doris.spark.rest.models.Schema;
-
 import com.google.common.base.Preconditions;
 import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.vector.BigIntVector;
 import org.apache.arrow.vector.BitVector;
+import org.apache.arrow.vector.DateDayVector;
 import org.apache.arrow.vector.DecimalVector;
 import org.apache.arrow.vector.FieldVector;
 import org.apache.arrow.vector.FixedSizeBinaryVector;
@@ -32,6 +29,7 @@ import org.apache.arrow.vector.Float8Vector;
 import org.apache.arrow.vector.IntVector;
 import org.apache.arrow.vector.SmallIntVector;
+import org.apache.arrow.vector.TimeStampMicroVector;
 import org.apache.arrow.vector.TinyIntVector;
 import org.apache.arrow.vector.VarBinaryVector;
 import org.apache.arrow.vector.VarCharVector;
 
@@ -43,6 +41,9 @@ import org.apache.arrow.vector.ipc.ArrowStreamReader;
 import org.apache.arrow.vector.types.Types;
 import org.apache.commons.lang3.ArrayUtils;
+import org.apache.doris.sdk.thrift.TScanBatchResult;
+import org.apache.doris.spark.exception.DorisException;
+import org.apache.doris.spark.rest.models.Schema;
 import org.apache.spark.sql.types.Decimal;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -54,7 +55,13 @@ import java.math.BigInteger;
 import java.nio.charset.StandardCharsets;
 import java.sql.Date;
+import java.time.Instant;
 import java.time.LocalDate;
+import java.time.LocalDateTime;
+import java.time.ZoneId;
+import java.time.format.DateTimeFormatter;
+import java.time.format.DateTimeFormatterBuilder;
+import java.time.temporal.ChronoField;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 
@@ -68,6 +75,11 @@ public class RowBatch {
 
     private static final Logger logger = LoggerFactory.getLogger(RowBatch.class);
 
+    private static final DateTimeFormatter DATE_TIME_FORMATTER = new DateTimeFormatterBuilder()
+            .appendPattern("yyyy-MM-dd HH:mm:ss")
+            .appendFraction(ChronoField.MICRO_OF_SECOND, 0, 6, true)
+            .toFormatter();
+
     public static class Row {
         private final List<Object> cols;
 
@@ -301,21 +313,68 @@ public void convertArrowToRowBatch() throws DorisException {
                         break;
                     case "DATE":
                     case "DATEV2":
-                        Preconditions.checkArgument(mt.equals(Types.MinorType.VARCHAR),
-                                typeMismatchMessage(currentType, mt));
-                        VarCharVector date = (VarCharVector) curFieldVector;
-                        for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
-                            if (date.isNull(rowIndex)) {
-                                addValueToRow(rowIndex, null);
-                                continue;
+                        Preconditions.checkArgument(mt.equals(Types.MinorType.VARCHAR)
+                                || mt.equals(Types.MinorType.DATEDAY), typeMismatchMessage(currentType, mt));
+                        if (mt.equals(Types.MinorType.VARCHAR)) {
+                            VarCharVector date = (VarCharVector) curFieldVector;
+                            for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
+                                if (date.isNull(rowIndex)) {
+                                    addValueToRow(rowIndex, null);
+                                    continue;
+                                }
+                                String stringValue = new String(date.get(rowIndex));
+                                LocalDate localDate = LocalDate.parse(stringValue);
+                                addValueToRow(rowIndex, Date.valueOf(localDate));
+                            }
+                        } else {
+                            DateDayVector date = (DateDayVector) curFieldVector;
+                            for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
+                                if (date.isNull(rowIndex)) {
+                                    addValueToRow(rowIndex, null);
+                                    continue;
+                                }
+                                LocalDate localDate = LocalDate.ofEpochDay(date.get(rowIndex));
+                                addValueToRow(rowIndex, Date.valueOf(localDate));
                             }
-                            String stringValue = new String(date.get(rowIndex));
-                            LocalDate localDate = LocalDate.parse(stringValue);
-                            addValueToRow(rowIndex, Date.valueOf(localDate));
                         }
                         break;
                     case "DATETIME":
                     case "DATETIMEV2":
+                        Preconditions.checkArgument(mt.equals(Types.MinorType.VARCHAR)
+                                        || mt.equals(Types.MinorType.TIMESTAMPMICRO),
+                                typeMismatchMessage(currentType, mt));
+                        if (mt.equals(Types.MinorType.VARCHAR)) {
+                            VarCharVector varCharVector = (VarCharVector) curFieldVector;
+                            for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
+                                if (varCharVector.isNull(rowIndex)) {
+                                    addValueToRow(rowIndex, null);
+                                    continue;
+                                }
+                                String value = new String(varCharVector.get(rowIndex), StandardCharsets.UTF_8);
+                                addValueToRow(rowIndex, value);
+                            }
+                        } else {
+                            TimeStampMicroVector vector = (TimeStampMicroVector) curFieldVector;
+                            for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
+                                if (vector.isNull(rowIndex)) {
+                                    addValueToRow(rowIndex, null);
+                                    continue;
+                                }
+                                long time = vector.get(rowIndex);
+                                Instant instant;
+                                if (time / 10000000000L == 0) { // datetime(0)
+                                    instant = Instant.ofEpochSecond(time);
+                                } else if (time / 10000000000000L == 0) { // datetime(3)
+                                    instant = Instant.ofEpochMilli(time);
+                                } else { // datetime(6)
+                                    instant = Instant.ofEpochSecond(time / 1000000, time % 1000000 * 1000);
+                                }
+                                LocalDateTime dateTime = LocalDateTime.ofInstant(instant, ZoneId.systemDefault());
+                                String formatted = DATE_TIME_FORMATTER.format(dateTime);
+                                addValueToRow(rowIndex, formatted);
+                            }
+                        }
+                        break;
                     case "CHAR":
                     case "VARCHAR":
                     case "STRING":
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala
index 0156d37e..fc01d6b7 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala
@@ -21,7 +21,7 @@ import org.apache.commons.lang3.exception.ExceptionUtils
 import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
 import org.apache.doris.spark.exception.DorisException
 import org.apache.doris.spark.jdbc.JdbcUtils
-import org.apache.doris.spark.load.{CommitMessage, StreamLoader}
+import org.apache.doris.spark.load.CommitMessage
 import org.apache.doris.spark.sql.DorisSourceProvider.SHORT_NAME
 import org.apache.doris.spark.writer.DorisWriter
 import org.apache.spark.SparkConf
diff --git a/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java b/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java
index cb7e0b8f..1cf4136e 100644
--- a/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java
+++ b/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java
@@ -17,6 +17,10 @@
 
 package org.apache.doris.spark.serialization;
 
+import org.apache.arrow.vector.DateDayVector;
+import org.apache.arrow.vector.TimeStampMicroVector;
+import org.apache.arrow.vector.types.DateUnit;
+import org.apache.arrow.vector.types.TimeUnit;
 import org.apache.doris.sdk.thrift.TScanBatchResult;
 import org.apache.doris.sdk.thrift.TStatus;
 import org.apache.doris.sdk.thrift.TStatusCode;
@@ -67,9 +71,11 @@
 import java.math.BigInteger;
 import java.nio.charset.StandardCharsets;
 import java.sql.Date;
+import java.time.LocalDateTime;
+import java.time.ZoneId;
 import java.util.Arrays;
 import java.util.List;
 import java.util.NoSuchElementException;
 
 import static org.hamcrest.core.StringStartsWith.startsWith;
 
@@ -458,6 +466,7 @@ public void testDate() throws DorisException, IOException {
         ImmutableList.Builder<Field> childrenBuilder = ImmutableList.builder();
         childrenBuilder.add(new Field("k1", FieldType.nullable(new ArrowType.Utf8()), null));
         childrenBuilder.add(new Field("k2", FieldType.nullable(new ArrowType.Utf8()), null));
+        childrenBuilder.add(new Field("k3", FieldType.nullable(new ArrowType.Date(DateUnit.DAY)), null));
 
         VectorSchemaRoot root = VectorSchemaRoot.create(
                 new org.apache.arrow.vector.types.pojo.Schema(childrenBuilder.build(), null),
@@ -490,6 +499,14 @@ public void testDate() throws DorisException, IOException {
         dateV2Vector.setSafe(0, "2023-08-10".getBytes());
         vector.setValueCount(1);
 
+        vector = root.getVector("k3");
+        DateDayVector dateNewVector = (DateDayVector) vector;
+        dateNewVector.setInitialCapacity(1);
+        dateNewVector.allocateNew();
+        dateNewVector.setIndexDefined(0);
+        dateNewVector.setSafe(0, 19802);
+        vector.setValueCount(1);
+
         arrowStreamWriter.writeBatch();
         arrowStreamWriter.end();
 
@@ -505,7 +522,8 @@ public void testDate() throws DorisException, IOException {
 
         String schemaStr = "{\"properties\":[" +
                 "{\"type\":\"DATE\",\"name\":\"k1\",\"comment\":\"\"}, " +
-                "{\"type\":\"DATEV2\",\"name\":\"k2\",\"comment\":\"\"}" +
+                "{\"type\":\"DATEV2\",\"name\":\"k2\",\"comment\":\"\"}, " +
+                "{\"type\":\"DATEV2\",\"name\":\"k3\",\"comment\":\"\"}" +
                 "], \"status\":200}";
 
         Schema schema = RestService.parseSchema(schemaStr, logger);
@@ -516,6 +534,7 @@ public void testDate() throws DorisException, IOException {
         List<Object> actualRow0 = rowBatch.next();
         Assert.assertEquals(Date.valueOf("2023-08-09"), actualRow0.get(0));
         Assert.assertEquals(Date.valueOf("2023-08-10"), actualRow0.get(1));
+        Assert.assertEquals(Date.valueOf("2024-03-20"), actualRow0.get(2));
 
         Assert.assertFalse(rowBatch.hasNext());
         thrown.expect(NoSuchElementException.class);
@@ -737,4 +756,98 @@ public void testStruct() throws IOException, DorisException {
 
     }
 
+    @Test
+    public void testDateTime() throws IOException, DorisException {
+
+        ImmutableList.Builder<Field> childrenBuilder = ImmutableList.builder();
+        childrenBuilder.add(new Field("k1", FieldType.nullable(new ArrowType.Utf8()), null));
+        childrenBuilder.add(new Field("k2", FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MICROSECOND,
+                null)), null));
+
+        VectorSchemaRoot root = VectorSchemaRoot.create(
+                new org.apache.arrow.vector.types.pojo.Schema(childrenBuilder.build(), null),
+                new RootAllocator(Integer.MAX_VALUE));
+        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+        ArrowStreamWriter arrowStreamWriter = new ArrowStreamWriter(
+                root,
+                new DictionaryProvider.MapDictionaryProvider(),
+                outputStream);
+
+        arrowStreamWriter.start();
+        root.setRowCount(3);
+
+        FieldVector vector = root.getVector("k1");
+        VarCharVector datetimeVector = (VarCharVector) vector;
+        datetimeVector.setInitialCapacity(3);
+        datetimeVector.allocateNew();
+        datetimeVector.setIndexDefined(0);
+        datetimeVector.setValueLengthSafe(0, 20);
+        datetimeVector.setSafe(0, "2024-03-20 00:00:00".getBytes());
+        datetimeVector.setIndexDefined(1);
+        datetimeVector.setValueLengthSafe(1, 20);
+        datetimeVector.setSafe(1, "2024-03-20 00:00:01".getBytes());
+        datetimeVector.setIndexDefined(2);
+        datetimeVector.setValueLengthSafe(2, 20);
+        datetimeVector.setSafe(2, "2024-03-20 00:00:02".getBytes());
+        vector.setValueCount(3);
+
+        LocalDateTime localDateTime = LocalDateTime.of(2024, 3, 20,
+                0, 0, 0, 123456000);
+        long second = localDateTime.atZone(ZoneId.systemDefault()).toEpochSecond();
+        int nano = localDateTime.getNano();
+
+        vector = root.getVector("k2");
+        TimeStampMicroVector datetimeV2Vector = (TimeStampMicroVector) vector;
+        datetimeV2Vector.setInitialCapacity(3);
+        datetimeV2Vector.allocateNew();
+        datetimeV2Vector.setIndexDefined(0);
+        datetimeV2Vector.setSafe(0, second);
+        datetimeV2Vector.setIndexDefined(1);
+        datetimeV2Vector.setSafe(1, second * 1000 + nano / 1000000);
+        datetimeV2Vector.setIndexDefined(2);
+        datetimeV2Vector.setSafe(2, second * 1000000 + nano / 1000);
+        vector.setValueCount(3);
+
+        arrowStreamWriter.writeBatch();
+
+        arrowStreamWriter.end();
+        arrowStreamWriter.close();
+
+        TStatus status = new TStatus();
+        status.setStatusCode(TStatusCode.OK);
+        TScanBatchResult scanBatchResult = new TScanBatchResult();
+        scanBatchResult.setStatus(status);
+        scanBatchResult.setEos(false);
+        scanBatchResult.setRows(outputStream.toByteArray());
+
+        String schemaStr = "{\"properties\":[" +
+                "{\"type\":\"DATETIME\",\"name\":\"k1\",\"comment\":\"\"}, " +
+                "{\"type\":\"DATETIMEV2\",\"name\":\"k2\",\"comment\":\"\"}" +
+                "], \"status\":200}";
+
+        Schema schema = RestService.parseSchema(schemaStr, logger);
+
+        RowBatch rowBatch = new RowBatch(scanBatchResult, schema);
+
+        Assert.assertTrue(rowBatch.hasNext());
+        List<Object> actualRow0 = rowBatch.next();
+        Assert.assertEquals("2024-03-20 00:00:00", actualRow0.get(0));
+        Assert.assertEquals("2024-03-20 00:00:00", actualRow0.get(1));
+
+        List<Object> actualRow1 = rowBatch.next();
+        Assert.assertEquals("2024-03-20 00:00:01", actualRow1.get(0));
+        Assert.assertEquals("2024-03-20 00:00:00.123", actualRow1.get(1));
+
+        List<Object> actualRow2 = rowBatch.next();
+        Assert.assertEquals("2024-03-20 00:00:02", actualRow2.get(0));
+        Assert.assertEquals("2024-03-20 00:00:00.123456", actualRow2.get(1));
+
+        Assert.assertFalse(rowBatch.hasNext());
+        thrown.expect(NoSuchElementException.class);
+        thrown.expectMessage(startsWith("Get row offset:"));
+        rowBatch.next();
+
+    }
+
 }