diff --git a/flink-doris-connector/src/main/java/org/apache/doris/flink/deserialization/converter/DorisRowConverter.java b/flink-doris-connector/src/main/java/org/apache/doris/flink/deserialization/converter/DorisRowConverter.java index 6fa3be992..ebc0ff664 100644 --- a/flink-doris-connector/src/main/java/org/apache/doris/flink/deserialization/converter/DorisRowConverter.java +++ b/flink-doris-connector/src/main/java/org/apache/doris/flink/deserialization/converter/DorisRowConverter.java @@ -220,7 +220,9 @@ protected DeserializationConverter createInternalConverter(LogicalType type) { case ARRAY: return val -> convertArrayData(((List) val).toArray(), type); case ROW: + return val -> convertRowData((Map) val, type); case MAP: + return val -> convertMapData((Map) val, type); case MULTISET: case RAW: default: @@ -298,6 +300,33 @@ private ArrayData convertArrayData(Object[] array, LogicalType type) { return arrayData; } + private MapData convertMapData(Map map, LogicalType type){ + MapType mapType = (MapType) type; + DeserializationConverter keyConverter = createNullableInternalConverter(mapType.getKeyType()); + DeserializationConverter valueConverter = createNullableInternalConverter(mapType.getValueType()); + Map result = new HashMap<>(); + for(Map.Entry entry : map.entrySet()){ + Object key = keyConverter.deserialize(entry.getKey()); + Object value = valueConverter.deserialize(entry.getValue()); + result.put(key, value); + } + GenericMapData mapData = new GenericMapData(result); + return mapData; + } + + private RowData convertRowData(Map row, LogicalType type) { + RowType rowType = (RowType) type; + GenericRowData rowData = new GenericRowData(row.size()); + int index = 0; + for(Map.Entry entry : row.entrySet()){ + DeserializationConverter converter = createNullableInternalConverter(rowType.getTypeAt(index)); + Object value = converter.deserialize(entry.getValue()); + rowData.setField(index, value); + index++; + } + return rowData; + } + private List convertArrayData(ArrayData array, LogicalType type){ if(array instanceof GenericArrayData){ return Arrays.asList(((GenericArrayData)array).toObjectArray()); diff --git a/flink-doris-connector/src/main/java/org/apache/doris/flink/serialization/RowBatch.java b/flink-doris-connector/src/main/java/org/apache/doris/flink/serialization/RowBatch.java index de63a6ec1..ad8bb722f 100644 --- a/flink-doris-connector/src/main/java/org/apache/doris/flink/serialization/RowBatch.java +++ b/flink-doris-connector/src/main/java/org/apache/doris/flink/serialization/RowBatch.java @@ -32,6 +32,9 @@ import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.complex.impl.UnionMapReader; import org.apache.arrow.vector.ipc.ArrowStreamReader; import org.apache.arrow.vector.types.Types; import org.apache.doris.flink.exception.DorisException; @@ -50,7 +53,9 @@ import java.time.LocalDateTime; import java.time.format.DateTimeFormatter; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.NoSuchElementException; /** @@ -329,6 +334,31 @@ private boolean doConvert(int col, //todo: when the subtype of array is date, conversion is required addValueToRow(rowIndex, listValue); break; + case "MAP": + if (!minorType.equals(Types.MinorType.MAP)) return false; + MapVector mapVector = (MapVector) fieldVector; + UnionMapReader reader = mapVector.getReader(); + if (mapVector.isNull(rowIndex)) { + addValueToRow(rowIndex, null); + break; + } + reader.setPosition(rowIndex); + Map mapValue = new HashMap<>(); + while (reader.next()) { + mapValue.put(reader.key().readObject().toString(), reader.value().readObject()); + } + addValueToRow(rowIndex, mapValue); + break; + case "STRUCT": + if (!minorType.equals(Types.MinorType.STRUCT)) return false; + StructVector structVector = (StructVector) fieldVector; + if (structVector.isNull(rowIndex)) { + addValueToRow(rowIndex, null); + break; + } + Map structValue = structVector.getObject(rowIndex); + addValueToRow(rowIndex, structValue); + break; default: String errMsg = "Unsupported type " + schema.get(col).getType(); logger.error(errMsg); diff --git a/flink-doris-connector/src/test/java/org/apache/doris/flink/serialization/TestRowBatch.java b/flink-doris-connector/src/test/java/org/apache/doris/flink/serialization/TestRowBatch.java index e2ee8272d..47071f53c 100644 --- a/flink-doris-connector/src/test/java/org/apache/doris/flink/serialization/TestRowBatch.java +++ b/flink-doris-connector/src/test/java/org/apache/doris/flink/serialization/TestRowBatch.java @@ -17,6 +17,7 @@ package org.apache.doris.flink.serialization; +import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.BitVector; @@ -30,17 +31,25 @@ import org.apache.arrow.vector.VarBinaryVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.complex.impl.NullableStructWriter; +import org.apache.arrow.vector.complex.impl.UnionMapWriter; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.ipc.ArrowStreamWriter; import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.util.Text; +import org.apache.doris.flink.exception.DorisException; import org.apache.doris.flink.rest.RestService; import org.apache.doris.flink.rest.models.Schema; import org.apache.doris.sdk.thrift.TScanBatchResult; import org.apache.doris.sdk.thrift.TStatus; import org.apache.doris.sdk.thrift.TStatusCode; +import org.apache.flink.shaded.guava30.com.google.common.collect.ImmutableList; +import org.apache.flink.shaded.guava30.com.google.common.collect.ImmutableMap; import org.apache.flink.table.data.DecimalData; import org.junit.Assert; import org.junit.Rule; @@ -50,7 +59,9 @@ import org.slf4j.LoggerFactory; import java.io.ByteArrayOutputStream; +import java.io.IOException; import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; import java.time.LocalDate; import java.time.LocalDateTime; import java.util.ArrayList; @@ -441,4 +452,137 @@ public void testDecimalV2() throws Exception { thrown.expectMessage(startsWith("Get row offset:")); rowBatch.next(); } + + @Test + public void testMap() throws IOException, DorisException { + + ImmutableList mapChildren = ImmutableList.of( + new Field("child", new FieldType(false, new ArrowType.Struct(), null), + ImmutableList.of( + new Field("key", new FieldType(false, new ArrowType.Utf8(), null), null), + new Field("value", new FieldType(false, new ArrowType.Int(32, true), null), + null) + ) + )); + + ImmutableList fields = ImmutableList.of( + new Field("col_map", new FieldType(false, new ArrowType.Map(false), null), + mapChildren) + ); + + RootAllocator allocator = new RootAllocator(Integer.MAX_VALUE); + VectorSchemaRoot root = VectorSchemaRoot.create( + new org.apache.arrow.vector.types.pojo.Schema(fields, null), allocator); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + ArrowStreamWriter arrowStreamWriter = new ArrowStreamWriter( + root, + new DictionaryProvider.MapDictionaryProvider(), + outputStream); + + arrowStreamWriter.start(); + root.setRowCount(3); + + MapVector mapVector = (MapVector) root.getVector("col_map"); + mapVector.allocateNew(); + UnionMapWriter mapWriter = mapVector.getWriter(); + for (int i = 0; i < 3; i++) { + mapWriter.setPosition(i); + mapWriter.startMap(); + mapWriter.startEntry(); + String key = "k" + (i + 1); + byte[] bytes = key.getBytes(StandardCharsets.UTF_8); + ArrowBuf buffer = allocator.buffer(bytes.length); + buffer.setBytes(0, bytes); + mapWriter.key().varChar().writeVarChar(0, bytes.length, buffer); + buffer.close(); + mapWriter.value().integer().writeInt(i); + mapWriter.endEntry(); + mapWriter.endMap(); + } + mapWriter.setValueCount(3); + + arrowStreamWriter.writeBatch(); + + arrowStreamWriter.end(); + arrowStreamWriter.close(); + + TStatus status = new TStatus(); + status.setStatusCode(TStatusCode.OK); + TScanBatchResult scanBatchResult = new TScanBatchResult(); + scanBatchResult.setStatus(status); + scanBatchResult.setEos(false); + scanBatchResult.setRows(outputStream.toByteArray()); + + String schemaStr = "{\"properties\":[{\"type\":\"MAP\",\"name\":\"col_map\",\"comment\":\"\"}" + + "], \"status\":200}"; + + + Schema schema = RestService.parseSchema(schemaStr, logger); + + RowBatch rowBatch = new RowBatch(scanBatchResult, schema).readArrow(); + Assert.assertTrue(rowBatch.hasNext()); + Assert.assertTrue(ImmutableMap.of("k1", 0).equals(rowBatch.next().get(0))); + Assert.assertTrue(rowBatch.hasNext()); + Assert.assertTrue(ImmutableMap.of("k2", 1).equals(rowBatch.next().get(0))); + Assert.assertTrue(rowBatch.hasNext()); + Assert.assertTrue(ImmutableMap.of("k3", 2).equals(rowBatch.next().get(0))); + Assert.assertFalse(rowBatch.hasNext()); + + } + + @Test + public void testStruct() throws IOException, DorisException { + + ImmutableList fields = ImmutableList.of( + new Field("col_struct", new FieldType(false, new ArrowType.Struct(), null), + ImmutableList.of(new Field("a", new FieldType(false, new ArrowType.Utf8(), null), null), + new Field("b", new FieldType(false, new ArrowType.Int(32, true), null), null)) + )); + + RootAllocator allocator = new RootAllocator(Integer.MAX_VALUE); + VectorSchemaRoot root = VectorSchemaRoot.create( + new org.apache.arrow.vector.types.pojo.Schema(fields, null), allocator); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + ArrowStreamWriter arrowStreamWriter = new ArrowStreamWriter( + root, + new DictionaryProvider.MapDictionaryProvider(), + outputStream); + + arrowStreamWriter.start(); + root.setRowCount(3); + + StructVector structVector = (StructVector) root.getVector("col_struct"); + structVector.allocateNew(); + NullableStructWriter writer = structVector.getWriter(); + writer.setPosition(0); + writer.start(); + byte[] bytes = "a1".getBytes(StandardCharsets.UTF_8); + ArrowBuf buffer = allocator.buffer(bytes.length); + buffer.setBytes(0, bytes); + writer.varChar("a").writeVarChar(0, bytes.length, buffer); + buffer.close(); + writer.integer("b").writeInt(1); + writer.end(); + writer.setValueCount(1); + + arrowStreamWriter.writeBatch(); + + arrowStreamWriter.end(); + arrowStreamWriter.close(); + + TStatus status = new TStatus(); + status.setStatusCode(TStatusCode.OK); + TScanBatchResult scanBatchResult = new TScanBatchResult(); + scanBatchResult.setStatus(status); + scanBatchResult.setEos(false); + scanBatchResult.setRows(outputStream.toByteArray()); + + String schemaStr = "{\"properties\":[{\"type\":\"STRUCT\",\"name\":\"col_struct\",\"comment\":\"\"}" + + "], \"status\":200}"; + Schema schema = RestService.parseSchema(schemaStr, logger); + + RowBatch rowBatch = new RowBatch(scanBatchResult, schema).readArrow(); + Assert.assertTrue(rowBatch.hasNext()); + Assert.assertTrue(ImmutableMap.of("a", new Text("a1"),"b",1).equals(rowBatch.next().get(0))); + } }