From 1fb186741c3236d647db8ff28ac1b448b931ba79 Mon Sep 17 00:00:00 2001 From: yuzelin Date: Mon, 5 Aug 2024 17:12:45 +0800 Subject: [PATCH] fix --- .../parquet/reader/NestedColumnReader.java | 93 ++- .../parquet/reader/NestedPositionUtil.java | 61 +- .../reader/NestedPrimitiveColumnReader.java | 34 +- .../parquet/ParquetColumnVectorTest.java | 782 ++++++++++++++++++ 4 files changed, 882 insertions(+), 88 deletions(-) create mode 100644 paimon-format/src/test/java/org/apache/paimon/format/parquet/ParquetColumnVectorTest.java diff --git a/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/NestedColumnReader.java b/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/NestedColumnReader.java index 165527adc688..c89c77603dac 100644 --- a/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/NestedColumnReader.java +++ b/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/NestedColumnReader.java @@ -26,7 +26,6 @@ import org.apache.paimon.data.columnar.writable.WritableColumnVector; import org.apache.paimon.format.parquet.position.CollectionPosition; import org.apache.paimon.format.parquet.position.LevelDelegation; -import org.apache.paimon.format.parquet.position.RowPosition; import org.apache.paimon.format.parquet.type.ParquetField; import org.apache.paimon.format.parquet.type.ParquetGroupField; import org.apache.paimon.format.parquet.type.ParquetPrimitiveField; @@ -41,6 +40,7 @@ import org.apache.parquet.column.page.PageReadStore; import java.io.IOException; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -86,7 +86,7 @@ public NestedColumnReader(boolean isUtcTimestamp, PageReadStore pages, ParquetFi @Override public void readToVector(int readNumber, WritableColumnVector vector) throws IOException { - readData(field, readNumber, vector, false, false); + readData(field, readNumber, vector, false, false, false); } private Pair readData( @@ -94,17 +94,18 @@ 
private Pair readData( int readNumber, ColumnVector vector, boolean inside, - boolean parentIsRowType) + boolean readRowField, + boolean readMapKey) throws IOException { if (field.getType() instanceof RowType) { return readRow((ParquetGroupField) field, readNumber, vector, inside); } else if (field.getType() instanceof MapType || field.getType() instanceof MultisetType) { - return readMap((ParquetGroupField) field, readNumber, vector, inside); + return readMap((ParquetGroupField) field, readNumber, vector, inside, readRowField); } else if (field.getType() instanceof ArrayType) { - return readArray((ParquetGroupField) field, readNumber, vector, inside); + return readArray((ParquetGroupField) field, readNumber, vector, inside, readRowField); } else { return readPrimitive( - (ParquetPrimitiveField) field, readNumber, vector, parentIsRowType); + (ParquetPrimitiveField) field, readNumber, vector, readRowField, readMapKey); } } @@ -112,45 +113,60 @@ private Pair readRow( ParquetGroupField field, int readNumber, ColumnVector vector, boolean inside) throws IOException { HeapRowVector heapRowVector = (HeapRowVector) vector; - LevelDelegation levelDelegation = null; + LevelDelegation longest = null; List children = field.getChildren(); WritableColumnVector[] childrenVectors = heapRowVector.getFields(); WritableColumnVector[] finalChildrenVectors = new WritableColumnVector[childrenVectors.length]; for (int i = 0; i < children.size(); i++) { Pair tuple = - readData(children.get(i), readNumber, childrenVectors[i], true, true); - levelDelegation = tuple.getLeft(); + readData(children.get(i), readNumber, childrenVectors[i], true, true, false); + LevelDelegation current = tuple.getLeft(); + if (longest == null) { + longest = current; + } else if (current.getDefinitionLevel().length > longest.getDefinitionLevel().length) { + longest = current; + } finalChildrenVectors[i] = tuple.getRight(); } - if (levelDelegation == null) { + if (longest == null) { throw new RuntimeException( 
String.format("Row field does not have any children: %s.", field)); } - RowPosition rowPosition = - NestedPositionUtil.calculateRowOffsets( - field, - levelDelegation.getDefinitionLevel(), - levelDelegation.getRepetitionLevel()); + int len = ((AbstractHeapVector) finalChildrenVectors[0]).getLen(); + boolean[] isNull = new boolean[len]; + Arrays.fill(isNull, true); + boolean hasNull = false; + for (int i = 0; i < len; i++) { + for (WritableColumnVector child : finalChildrenVectors) { + isNull[i] = isNull[i] && child.isNullAt(i); + } + if (isNull[i]) { + hasNull = true; + } + } // If row was inside the structure, then we need to renew the vector to reset the // capacity. if (inside) { - heapRowVector = - new HeapRowVector(rowPosition.getPositionsCount(), finalChildrenVectors); + heapRowVector = new HeapRowVector(len, finalChildrenVectors); } else { heapRowVector.setFields(finalChildrenVectors); } - if (rowPosition.getIsNull() != null) { - setFieldNullFlag(rowPosition.getIsNull(), heapRowVector); + if (hasNull) { + setFieldNullFlag(isNull, heapRowVector); } - return Pair.of(levelDelegation, heapRowVector); + return Pair.of(longest, heapRowVector); } private Pair readMap( - ParquetGroupField field, int readNumber, ColumnVector vector, boolean inside) + ParquetGroupField field, + int readNumber, + ColumnVector vector, + boolean inside, + boolean readRowField) throws IOException { HeapMapVector mapVector = (HeapMapVector) vector; mapVector.reset(); @@ -160,10 +176,21 @@ private Pair readMap( "Maps must have two type parameters, found %s", children.size()); Pair keyTuple = - readData(children.get(0), readNumber, mapVector.getKeyColumnVector(), true, false); + readData( + children.get(0), + readNumber, + mapVector.getKeyColumnVector(), + true, + false, + true); Pair valueTuple = readData( - children.get(1), readNumber, mapVector.getValueColumnVector(), true, false); + children.get(1), + readNumber, + mapVector.getValueColumnVector(), + true, + false, + false); 
LevelDelegation levelDelegation = keyTuple.getLeft(); @@ -171,7 +198,8 @@ private Pair readMap( NestedPositionUtil.calculateCollectionOffsets( field, levelDelegation.getDefinitionLevel(), - levelDelegation.getRepetitionLevel()); + levelDelegation.getRepetitionLevel(), + readRowField); // If map was inside the structure, then we need to renew the vector to reset the // capacity. @@ -197,7 +225,11 @@ private Pair readMap( } private Pair readArray( - ParquetGroupField field, int readNumber, ColumnVector vector, boolean inside) + ParquetGroupField field, + int readNumber, + ColumnVector vector, + boolean inside, + boolean readRowField) throws IOException { HeapArrayVector arrayVector = (HeapArrayVector) vector; arrayVector.reset(); @@ -207,14 +239,15 @@ private Pair readArray( "Arrays must have a single type parameter, found %s", children.size()); Pair tuple = - readData(children.get(0), readNumber, arrayVector.getChild(), true, false); + readData(children.get(0), readNumber, arrayVector.getChild(), true, false, false); LevelDelegation levelDelegation = tuple.getLeft(); CollectionPosition collectionPosition = NestedPositionUtil.calculateCollectionOffsets( field, levelDelegation.getDefinitionLevel(), - levelDelegation.getRepetitionLevel()); + levelDelegation.getRepetitionLevel(), + readRowField); // If array was inside the structure, then we need to renew the vector to reset the // capacity. 
@@ -236,7 +269,8 @@ private Pair readPrimitive( ParquetPrimitiveField field, int readNumber, ColumnVector vector, - boolean parentIsRowType) + boolean readRowField, + boolean readMapKey) throws IOException { ColumnDescriptor descriptor = field.getDescriptor(); NestedPrimitiveColumnReader reader = columnReaders.get(descriptor); @@ -248,7 +282,8 @@ private Pair readPrimitive( isUtcTimestamp, descriptor.getPrimitiveType(), field.getType(), - parentIsRowType); + readRowField, + readMapKey); columnReaders.put(descriptor, reader); } WritableColumnVector writableColumnVector = diff --git a/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/NestedPositionUtil.java b/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/NestedPositionUtil.java index 9757b94c7583..b43169a40b2c 100644 --- a/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/NestedPositionUtil.java +++ b/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/NestedPositionUtil.java @@ -19,7 +19,6 @@ package org.apache.paimon.format.parquet.reader; import org.apache.paimon.format.parquet.position.CollectionPosition; -import org.apache.paimon.format.parquet.position.RowPosition; import org.apache.paimon.format.parquet.type.ParquetField; import org.apache.paimon.utils.BooleanArrayList; import org.apache.paimon.utils.LongArrayList; @@ -29,50 +28,6 @@ /** Utils to calculate nested type position. */ public class NestedPositionUtil { - /** - * Calculate row offsets according to column's max repetition level, definition level, value's - * repetition level and definition level. Each row has three situation: - *
  • Row is not defined,because it's optional parent fields is null, this is decided by its - * parent's repetition level - *
  • Row is null - *
  • Row is defined and not empty. - * - * @param field field that contains the row column message include max repetition level and - * definition level. - * @param fieldRepetitionLevels int array with each value's repetition level. - * @param fieldDefinitionLevels int array with each value's definition level. - * @return {@link RowPosition} contains collections row count and isNull array. - */ - public static RowPosition calculateRowOffsets( - ParquetField field, int[] fieldDefinitionLevels, int[] fieldRepetitionLevels) { - int rowDefinitionLevel = field.getDefinitionLevel(); - int rowRepetitionLevel = field.getRepetitionLevel(); - int nullValuesCount = 0; - BooleanArrayList nullRowFlags = new BooleanArrayList(0); - for (int i = 0; i < fieldDefinitionLevels.length; i++) { - if (fieldRepetitionLevels[i] > rowRepetitionLevel) { - throw new IllegalStateException( - format( - "In parquet's row type field repetition level should not larger than row's repetition level. " - + "Row repetition level is %s, row field repetition level is %s.", - rowRepetitionLevel, fieldRepetitionLevels[i])); - } - - if (fieldDefinitionLevels[i] >= rowDefinitionLevel) { - // current row is defined and not empty - nullRowFlags.add(false); - } else { - // current row is null - nullRowFlags.add(true); - nullValuesCount++; - } - } - if (nullValuesCount == 0) { - return new RowPosition(null, fieldDefinitionLevels.length); - } - return new RowPosition(nullRowFlags.toArray(), nullRowFlags.size()); - } - /** * Calculate the collection's offsets according to column's max repetition level, definition * level, value's repetition level and definition level. Each collection (Array or Map) has four @@ -92,7 +47,10 @@ public static RowPosition calculateRowOffsets( * array. 
*/ public static CollectionPosition calculateCollectionOffsets( - ParquetField field, int[] definitionLevels, int[] repetitionLevels) { + ParquetField field, + int[] definitionLevels, + int[] repetitionLevels, + boolean readRowField) { int collectionDefinitionLevel = field.getDefinitionLevel(); int collectionRepetitionLevel = field.getRepetitionLevel() + 1; int offset = 0; @@ -110,7 +68,8 @@ public static CollectionPosition calculateCollectionOffsets( // empty // definitionLevels[i] == collectionDefinitionLevel => Collection is defined but // empty - // definitionLevels[i] == maxDefinitionLevel - 1 => Collection is defined but null + // definitionLevels[i] == collectionDefinitionLevel - 1 => Collection is defined but + // null if (definitionLevels[i] > collectionDefinitionLevel) { nullCollectionFlags.add(false); emptyCollectionFlags.add(false); @@ -127,6 +86,14 @@ public static CollectionPosition calculateCollectionOffsets( // must be set at each index for calculating lengths later emptyCollectionFlags.add(false); } + offsets.add(offset); + valueCount++; + } else if (definitionLevels[i] == collectionDefinitionLevel - 2 && readRowField) { + // row field should store null value + nullCollectionFlags.add(true); + nullValuesCount++; + emptyCollectionFlags.add(false); + offsets.add(offset); valueCount++; } diff --git a/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/NestedPrimitiveColumnReader.java b/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/NestedPrimitiveColumnReader.java index 7837eb8148e9..7ee33a0bb5cc 100644 --- a/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/NestedPrimitiveColumnReader.java +++ b/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/NestedPrimitiveColumnReader.java @@ -73,7 +73,8 @@ public class NestedPrimitiveColumnReader implements ColumnReader * *

    When (definitionLevel <= maxDefLevel - 2) we skip the value because children ColumnVector - * for OrcArrayColumnVector and OrcMapColumnVector don't contain empty and null set value. Stay - * consistent here. + * for OrcArrayColumnVector don't contain empty and null set value. Stay consistent here. * - *

    But notice that children of RowColumnVector still get null value when entire outer row is - * null, so when {@code parentIsRowType} is true the null value is still stored. + *

    For MAP, the value vector is the same as ARRAY. But the key vector isn't nullable, so just + * read value when definitionLevel == maxDefLevel. + * + *

    For ROW, RowColumnVector still get null value when definitionLevel == maxDefLevel - 2. */ private boolean readValue() throws IOException { int left = readPageIfNeed(); @@ -225,13 +229,19 @@ private boolean readValue() throws IOException { } else { lastValue.setValue(readPrimitiveTypedRow(dataType)); } - } else if (definitionLevel == maxDefLevel - 1) { - lastValue.setValue(null); } else { - if (parentIsRowType) { - lastValue.setValue(null); - } else { + if (readMapKey) { lastValue.skip(); + } else { + if (definitionLevel == maxDefLevel - 1) { + // null value inner set + lastValue.setValue(null); + } else if (definitionLevel == maxDefLevel - 2 && readRowField) { + lastValue.setValue(null); + } else { + // current set is empty or null + lastValue.skip(); + } } } return true; diff --git a/paimon-format/src/test/java/org/apache/paimon/format/parquet/ParquetColumnVectorTest.java b/paimon-format/src/test/java/org/apache/paimon/format/parquet/ParquetColumnVectorTest.java new file mode 100644 index 000000000000..e08e4f3ae19f --- /dev/null +++ b/paimon-format/src/test/java/org/apache/paimon/format/parquet/ParquetColumnVectorTest.java @@ -0,0 +1,782 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.format.parquet; + +import org.apache.paimon.data.BinaryString; +import org.apache.paimon.data.DataGetters; +import org.apache.paimon.data.GenericArray; +import org.apache.paimon.data.GenericMap; +import org.apache.paimon.data.GenericRow; +import org.apache.paimon.data.InternalArray; +import org.apache.paimon.data.InternalMap; +import org.apache.paimon.data.InternalRow; +import org.apache.paimon.data.columnar.ArrayColumnVector; +import org.apache.paimon.data.columnar.BytesColumnVector; +import org.apache.paimon.data.columnar.ColumnVector; +import org.apache.paimon.data.columnar.IntColumnVector; +import org.apache.paimon.data.columnar.MapColumnVector; +import org.apache.paimon.data.columnar.RowColumnVector; +import org.apache.paimon.data.columnar.VectorizedColumnBatch; +import org.apache.paimon.format.FormatReaderContext; +import org.apache.paimon.format.FormatWriter; +import org.apache.paimon.format.parquet.writer.RowDataParquetBuilder; +import org.apache.paimon.fs.Path; +import org.apache.paimon.fs.local.LocalFileIO; +import org.apache.paimon.options.Options; +import org.apache.paimon.reader.RecordReader; +import org.apache.paimon.reader.VectorizedRecordIterator; +import org.apache.paimon.types.DataTypes; +import org.apache.paimon.types.RowType; +import org.apache.paimon.utils.StringUtils; + +import org.apache.parquet.filter2.compat.FilterCompat; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import javax.annotation.Nullable; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Random; +import java.util.UUID; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.BiFunction; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Validate the 
{@link ColumnVector}s read by Parquet format. */ +public class ParquetColumnVectorTest { + + private @TempDir java.nio.file.Path tempDir; + + private static final Random RND = ThreadLocalRandom.current(); + private static final BiFunction BYTES_COLUMN_VECTOR_STRING_FUNC = + (cv, i) -> + cv.isNullAt(i) + ? "null" + : new String(((BytesColumnVector) cv).getBytes(i).getBytes()); + + @Test + public void testArrayString() throws IOException { + RowType rowType = + RowType.builder() + .field("array_string", DataTypes.ARRAY(DataTypes.STRING())) + .build(); + + int numRows = RND.nextInt(5) + 5; + ArrayObject expectedData = new ArrayObject(); + List rows = new ArrayList<>(numRows); + for (int i = 0; i < numRows; i++) { + if (RND.nextBoolean()) { + expectedData.add(null); + rows.add(GenericRow.of((Object) null)); + continue; + } + + int currentSize = RND.nextInt(5); + List currentStringArray = + IntStream.range(0, currentSize) + .mapToObj(idx -> randomString()) + .collect(Collectors.toList()); + expectedData.add(currentStringArray); + GenericArray array = + new GenericArray( + currentStringArray.stream().map(BinaryString::fromString).toArray()); + rows.add(GenericRow.of(array)); + } + + VectorizedRecordIterator iterator = createVectorizedRecordIterator(rowType, rows); + VectorizedColumnBatch batch = iterator.batch(); + InternalArray.ElementGetter getter = InternalArray.createElementGetter(DataTypes.STRING()); + + // validate row by row + for (int i = 0; i < numRows; i++) { + InternalRow row = iterator.next(); + expectedData.validateRow(row, i, getter); + } + assertThat(iterator.next()).isNull(); + + // validate ColumnVector + ArrayColumnVector arrayColumnVector = (ArrayColumnVector) batch.columns[0]; + expectedData.validateColumnVector(arrayColumnVector, getter); + + expectedData.validateInnerChild( + arrayColumnVector.getColumnVector(), BYTES_COLUMN_VECTOR_STRING_FUNC); + + iterator.releaseBatch(); + } + + @Test + public void testArrayArrayString() throws IOException { + 
RowType rowType = + RowType.builder() + .field( + "array_array_string", + DataTypes.ARRAY(DataTypes.ARRAY(DataTypes.STRING()))) + .build(); + + int numRows = RND.nextInt(5) + 5; + ArrayArrayObject expectedData = new ArrayArrayObject(); + List rows = new ArrayList<>(numRows); + for (int i = 0; i < numRows; i++) { + // outer null row + if (RND.nextBoolean()) { + expectedData.add(null); + rows.add(GenericRow.of((Object) null)); + continue; + } + + int arraySize = RND.nextInt(5); + ArrayObject arrayObject = new ArrayObject(); + GenericArray[] innerArrays = new GenericArray[arraySize]; + for (int aIdx = 0; aIdx < arraySize; aIdx++) { + // inner null array + if (RND.nextBoolean()) { + arrayObject.add(null); + innerArrays[aIdx] = null; + continue; + } + + int arrayStringSize = RND.nextInt(5); + List currentStringArray = + IntStream.range(0, arrayStringSize) + .mapToObj(idx -> randomString()) + .collect(Collectors.toList()); + arrayObject.add(currentStringArray); + innerArrays[aIdx] = + new GenericArray( + currentStringArray.stream() + .map(BinaryString::fromString) + .toArray()); + } + expectedData.add(arrayObject); + rows.add(GenericRow.of(new GenericArray(innerArrays))); + } + + VectorizedRecordIterator iterator = createVectorizedRecordIterator(rowType, rows); + VectorizedColumnBatch batch = iterator.batch(); + InternalArray.ElementGetter getter = InternalArray.createElementGetter(DataTypes.STRING()); + + // validate row by row + for (int i = 0; i < numRows; i++) { + InternalRow row = iterator.next(); + expectedData.validateRow(row, i, getter); + } + assertThat(iterator.next()).isNull(); + + // validate column vector + ArrayColumnVector arrayColumnVector = (ArrayColumnVector) batch.columns[0]; + + expectedData.validateOuterArray(arrayColumnVector, getter); + + ArrayColumnVector innerArrayColumnVector = + (ArrayColumnVector) arrayColumnVector.getColumnVector(); + expectedData.validateInnerArray(innerArrayColumnVector, getter); + + ColumnVector columnVector = 
innerArrayColumnVector.getColumnVector(); + expectedData.validateInnerChild(columnVector, BYTES_COLUMN_VECTOR_STRING_FUNC); + } + + @Test + public void testMapString() throws IOException { + RowType rowType = + RowType.builder() + .field("map_string", DataTypes.MAP(DataTypes.INT(), DataTypes.STRING())) + .build(); + + int numRows = RND.nextInt(5) + 5; + ArrayObject expectedData = new ArrayObject(); + List rows = new ArrayList<>(numRows); + for (int i = 0; i < numRows; i++) { + if (RND.nextBoolean()) { + expectedData.add(null); + rows.add(GenericRow.of((Object) null)); + continue; + } + + int currentSize = RND.nextInt(5); + List currentStringArray = + IntStream.range(0, currentSize) + .mapToObj(idx -> randomString()) + .collect(Collectors.toList()); + expectedData.add(currentStringArray); + Map map = new HashMap<>(); + for (int idx = 0; idx < currentSize; idx++) { + map.put(idx, BinaryString.fromString(currentStringArray.get(idx))); + } + rows.add(GenericRow.of(new GenericMap(map))); + } + + VectorizedRecordIterator iterator = createVectorizedRecordIterator(rowType, rows); + VectorizedColumnBatch batch = iterator.batch(); + InternalArray.ElementGetter getter = InternalArray.createElementGetter(DataTypes.STRING()); + + // validate row by row + for (int i = 0; i < numRows; i++) { + InternalRow row = iterator.next(); + assertThat(row).isNotNull(); + List expected = expectedData.data.get(i); + if (expected == null) { + assertThat(row.isNullAt(0)).isTrue(); + } else { + InternalMap map = row.getMap(0); + validateMapKeyArray(map.keyArray()); + InternalArray valueArray = map.valueArray(); + expectedData.validateNonNullArray(expected, valueArray, getter); + } + } + assertThat(iterator.next()).isNull(); + + // validate ColumnVector + MapColumnVector mapColumnVector = (MapColumnVector) batch.columns[0]; + IntColumnVector keyColumnVector = (IntColumnVector) mapColumnVector.getKeyColumnVector(); + validateMapKeyColumnVector(keyColumnVector, expectedData); + ColumnVector 
valueColumnVector = mapColumnVector.getValueColumnVector(); + expectedData.validateInnerChild(valueColumnVector, BYTES_COLUMN_VECTOR_STRING_FUNC); + + iterator.releaseBatch(); + } + + @Test + public void testMapArrayString() throws IOException { + RowType rowType = + RowType.builder() + .field( + "map_array_string", + DataTypes.MAP(DataTypes.INT(), DataTypes.ARRAY(DataTypes.STRING()))) + .build(); + + int numRows = RND.nextInt(5) + 5; + ArrayArrayObject expectedData = new ArrayArrayObject(); + List rows = new ArrayList<>(numRows); + for (int i = 0; i < numRows; i++) { + // outer null row + if (RND.nextBoolean()) { + expectedData.add(null); + rows.add(GenericRow.of((Object) null)); + continue; + } + + int mapSize = RND.nextInt(5); + ArrayObject arrayObject = new ArrayObject(); + Map map = new HashMap<>(); + for (int mIdx = 0; mIdx < mapSize; mIdx++) { + // null array value + if (RND.nextBoolean()) { + arrayObject.add(null); + map.put(mIdx, null); + continue; + } + + int currentSize = RND.nextInt(5); + List currentStringArray = + IntStream.range(0, currentSize) + .mapToObj(idx -> randomString()) + .collect(Collectors.toList()); + arrayObject.add(currentStringArray); + + map.put( + mIdx, + new GenericArray( + currentStringArray.stream() + .map(BinaryString::fromString) + .toArray())); + } + expectedData.add(arrayObject); + rows.add(GenericRow.of(new GenericMap(map))); + } + + VectorizedRecordIterator iterator = createVectorizedRecordIterator(rowType, rows); + VectorizedColumnBatch batch = iterator.batch(); + InternalArray.ElementGetter getter = InternalArray.createElementGetter(DataTypes.STRING()); + + // validate row by row + for (int i = 0; i < numRows; i++) { + InternalRow row = iterator.next(); + assertThat(row).isNotNull(); + ArrayObject expected = expectedData.data.get(i); + if (expected == null) { + assertThat(row.isNullAt(0)).isTrue(); + } else { + InternalMap map = row.getMap(0); + validateMapKeyArray(map.keyArray()); + InternalArray valueArray = 
map.valueArray(); + expected.validateArrayGetter(valueArray, getter); + } + } + assertThat(iterator.next()).isNull(); + + // validate column vector + MapColumnVector mapColumnVector = (MapColumnVector) batch.columns[0]; + IntColumnVector keyColumnVector = (IntColumnVector) mapColumnVector.getKeyColumnVector(); + validateMapKeyColumnVector(keyColumnVector, expectedData); + + ArrayColumnVector valueColumnVector = + (ArrayColumnVector) mapColumnVector.getValueColumnVector(); + expectedData.validateInnerArray(valueColumnVector, getter); + expectedData.validateInnerChild( + valueColumnVector.getColumnVector(), BYTES_COLUMN_VECTOR_STRING_FUNC); + + iterator.releaseBatch(); + } + + private void validateMapKeyArray(InternalArray keyArray) { + for (int i = 0; i < keyArray.size(); i++) { + assertThat(keyArray.getInt(i)).isEqualTo(i); + } + } + + private void validateMapKeyColumnVector( + IntColumnVector columnVector, ArrayObject expectedData) { + int idx = 0; + for (List values : expectedData.data) { + if (values != null) { + for (int i = 0; i < values.size(); i++) { + assertThat(columnVector.getInt(idx++)).isEqualTo(i); + } + } + } + } + + private void validateMapKeyColumnVector( + IntColumnVector columnVector, ArrayArrayObject expectedData) { + int idx = 0; + for (ArrayObject arrayObject : expectedData.data) { + if (arrayObject != null) { + for (int i = 0; i < arrayObject.data.size(); i++) { + assertThat(columnVector.getInt(idx++)).isEqualTo(i); + } + } + } + } + + @Test + public void testRow() throws IOException { + RowType rowType = + RowType.builder() + .field( + "row", + RowType.builder() + .field("f0", DataTypes.INT()) + .field("f1", DataTypes.ARRAY(DataTypes.STRING())) + .build()) + .build(); + + int numRows = RND.nextInt(5) + 5; + ArrayObject expectedData = new ArrayObject(); + List rows = new ArrayList<>(numRows); + List f0 = new ArrayList<>(); + for (int i = 0; i < numRows; i++) { + if (RND.nextBoolean()) { + expectedData.add(null); + f0.add(null); + 
rows.add(GenericRow.of((Object) null)); + continue; + } + + if (RND.nextInt(5) == 0) { + // set f1 null + expectedData.add(null); + f0.add(i); + rows.add(GenericRow.of(GenericRow.of(i, null))); + continue; + } + + int currentSize = RND.nextInt(5); + List currentStringArray = + IntStream.range(0, currentSize) + .mapToObj(idx -> randomString()) + .collect(Collectors.toList()); + expectedData.add(currentStringArray); + f0.add(i); + GenericArray array = + new GenericArray( + currentStringArray.stream().map(BinaryString::fromString).toArray()); + rows.add(GenericRow.of(GenericRow.of(i, array))); + } + + VectorizedRecordIterator iterator = createVectorizedRecordIterator(rowType, rows); + VectorizedColumnBatch batch = iterator.batch(); + InternalArray.ElementGetter getter = InternalArray.createElementGetter(DataTypes.STRING()); + + // validate row by row + for (int i = 0; i < numRows; i++) { + InternalRow row = iterator.next(); + assertThat(row).isNotNull(); + if (f0.get(i) == null && expectedData.data.get(i) == null) { + assertThat(row.isNullAt(0)).isTrue(); + } else { + InternalRow innerRow = row.getRow(0, 2); + + if (f0.get(i) == null) { + assertThat(innerRow.isNullAt(0)).isTrue(); + } else { + assertThat(innerRow.getInt(0)).isEqualTo(f0.get(i)); + } + + if (expectedData.data.get(i) == null) { + assertThat(innerRow.isNullAt(1)).isTrue(); + } else { + expectedData.validateNonNullArray( + expectedData.data.get(i), innerRow.getArray(1), getter); + } + } + } + assertThat(iterator.next()).isNull(); + + // validate ColumnVector + RowColumnVector rowColumnVector = (RowColumnVector) batch.columns[0]; + VectorizedColumnBatch innerBatch = rowColumnVector.getBatch(); + + IntColumnVector intColumnVector = (IntColumnVector) innerBatch.columns[0]; + for (int i = 0; i < numRows; i++) { + Integer f0Value = f0.get(i); + if (f0Value == null) { + assertThat(intColumnVector.isNullAt(i)).isTrue(); + } else { + assertThat(intColumnVector.getInt(i)).isEqualTo(f0Value); + } + } + + 
ArrayColumnVector arrayColumnVector = (ArrayColumnVector) innerBatch.columns[1]; + expectedData.validateColumnVector(arrayColumnVector, getter); + expectedData.validateInnerChild( + arrayColumnVector.getColumnVector(), BYTES_COLUMN_VECTOR_STRING_FUNC); + + iterator.releaseBatch(); + } + + @Test + public void testArrayRowArray() throws IOException { + RowType rowType = + RowType.builder() + .field( + "array_row_array", + DataTypes.ARRAY( + RowType.builder() + .field("f0", DataTypes.STRING()) + .field("f1", DataTypes.ARRAY(DataTypes.INT())) + .build())) + .build(); + + List rows = new ArrayList<>(4); + List f0 = new ArrayList<>(3); + for (int i = 0; i < 3; i++) { + f0.add(BinaryString.fromString(randomString())); + } + + GenericRow row00 = GenericRow.of(f0.get(0), new GenericArray(new Object[] {0, null})); + GenericRow row01 = GenericRow.of(f0.get(1), new GenericArray(new Object[] {})); + GenericArray array0 = new GenericArray(new GenericRow[] {row00, row01}); + rows.add(GenericRow.of(array0)); + + rows.add(GenericRow.of((Object) null)); + + GenericRow row20 = GenericRow.of(f0.get(2), new GenericArray(new Object[] {1})); + GenericArray array2 = new GenericArray(new GenericRow[] {row20}); + rows.add(GenericRow.of(array2)); + + GenericArray array3 = new GenericArray(new GenericRow[] {}); + rows.add(GenericRow.of(array3)); + + VectorizedRecordIterator iterator = createVectorizedRecordIterator(rowType, rows); + VectorizedColumnBatch batch = iterator.batch(); + + // validate row by row + InternalRow row0 = iterator.next(); + // array0 + InternalArray array = row0.getArray(0); + assertThat(array.size()).isEqualTo(2); + // row00 + InternalRow row = array.getRow(0, 1); + if (f0.get(0) == null) { + assertThat(row.isNullAt(0)).isTrue(); + } else { + assertThat(row.getString(0)).isEqualTo(f0.get(0)); + } + InternalArray innerArray = row.getArray(1); + assertThat(innerArray.size()).isEqualTo(2); + assertThat(innerArray.getInt(0)).isEqualTo(0); + 
assertThat(innerArray.isNullAt(1)).isTrue(); + // row01 + row = array.getRow(1, 1); + if (f0.get(1) == null) { + assertThat(row.isNullAt(0)).isTrue(); + } else { + assertThat(row.getString(0)).isEqualTo(f0.get(1)); + } + innerArray = row.getArray(1); + assertThat(innerArray.size()).isEqualTo(0); + + InternalRow row1 = iterator.next(); + assertThat(row1.isNullAt(0)).isTrue(); + + InternalRow row2 = iterator.next(); + // array2 + array = row2.getArray(0); + assertThat(array.size()).isEqualTo(1); + // row20 + row = array.getRow(0, 1); + if (f0.get(2) == null) { + assertThat(row.isNullAt(0)).isTrue(); + } else { + assertThat(row.getString(0)).isEqualTo(f0.get(2)); + } + innerArray = row.getArray(1); + assertThat(innerArray.size()).isEqualTo(1); + assertThat(innerArray.getInt(0)).isEqualTo(1); + + InternalRow row3 = iterator.next(); + // array2 + array = row3.getArray(0); + assertThat(array.size()).isEqualTo(0); + + assertThat(iterator.next()).isNull(); + + // validate ColumnVector + ArrayColumnVector arrayColumnVector = (ArrayColumnVector) batch.columns[0]; + assertThat(arrayColumnVector.isNullAt(0)).isFalse(); + assertThat(arrayColumnVector.isNullAt(1)).isTrue(); + assertThat(arrayColumnVector.isNullAt(2)).isFalse(); + assertThat(arrayColumnVector.isNullAt(3)).isFalse(); + + RowColumnVector rowColumnVector = (RowColumnVector) arrayColumnVector.getColumnVector(); + BytesColumnVector f0Vector = (BytesColumnVector) rowColumnVector.getBatch().columns[0]; + for (int i = 0; i < 3; i++) { + BinaryString s = f0.get(i); + if (s == null) { + assertThat(f0Vector.isNullAt(i)).isTrue(); + } else { + assertThat(new String(f0Vector.getBytes(i).getBytes())).isEqualTo(s.toString()); + } + } + ArrayColumnVector f1Vector = (ArrayColumnVector) rowColumnVector.getBatch().columns[1]; + InternalArray internalArray0 = f1Vector.getArray(0); + assertThat(internalArray0.size()).isEqualTo(2); + assertThat(internalArray0.isNullAt(0)).isFalse(); + assertThat(internalArray0.isNullAt(1)).isTrue(); + 
+        InternalArray internalArray1 = f1Vector.getArray(1);
+        assertThat(internalArray1.size()).isEqualTo(0);
+
+        InternalArray internalArray2 = f1Vector.getArray(2);
+        assertThat(internalArray2.size()).isEqualTo(1);
+        assertThat(internalArray2.isNullAt(0)).isFalse();
+
+        IntColumnVector intColumnVector = (IntColumnVector) f1Vector.getColumnVector();
+        assertThat(intColumnVector.getInt(0)).isEqualTo(0);
+        assertThat(intColumnVector.isNullAt(1)).isTrue();
+        assertThat(intColumnVector.getInt(2)).isEqualTo(1);
+
+        iterator.releaseBatch();
+    }
+    /**
+     * Writes {@code rows} to a new Parquet file and returns the first batch read back from it.
+     *
+     * The file is created under {@code tempDir} with a random UUID as its name. Only the result
+     * of a single {@code readBatch()} call is returned.
+     *
+     * NOTE(review): the {@code RecordReader} created here is never closed by this method and is
+     * not handed to the caller; confirm whether releasing the returned batch is sufficient.
+     *
+     * @param rowType schema used both for the writer and for the reader
+     * @param rows rows to write, in order
+     * @return the first batch of the written file, cast to {@code VectorizedRecordIterator}
+     * @throws IOException if writing or reading the file fails
+     */
+    private VectorizedRecordIterator createVectorizedRecordIterator(
+            RowType rowType, List rows) throws IOException {
+        Path path = new Path(tempDir.toString(), UUID.randomUUID().toString());
+        LocalFileIO fileIO = LocalFileIO.create();
+        // Write phase. "zstd" is presumably the compression codec and `false` the overwrite flag
+        // -- TODO confirm against the FormatWriterFactory / FileIO signatures.
+        ParquetWriterFactory writerFactory =
+                new ParquetWriterFactory(new RowDataParquetBuilder(rowType, new Options()));
+        FormatWriter writer = writerFactory.create(fileIO.newOutputStream(path, false), "zstd");
+        for (InternalRow row : rows) {
+            writer.addElement(row);
+        }
+        writer.flush();
+        writer.finish();
+        // Read phase. 1024 is presumably the batch size, so callers that expect all rows in the
+        // returned batch must write fewer rows than that -- TODO confirm.
+        ParquetReaderFactory readerFactory =
+                new ParquetReaderFactory(new Options(), rowType, 1024, FilterCompat.NOOP);
+
+        RecordReader reader =
+                readerFactory.createReader(
+                        new FormatReaderContext(fileIO, path, fileIO.getFileSize(path)));
+
+        RecordReader.RecordIterator iterator = reader.readBatch();
+        return (VectorizedRecordIterator) iterator;
+    }
+    /**
+     * Returns {@code null} when {@code RND.nextInt(5)} yields 0, otherwise the result of
+     * {@code StringUtils.getRandomString(RND, 1, 10)} (the 1 and 10 are presumably length bounds
+     * -- TODO confirm).
+     */
+    @Nullable
+    private String randomString() {
+        return RND.nextInt(5) == 0 ? null : StringUtils.getRandomString(RND, 1, 10);
+    }
+
+    /** Store generated data of ARRAY[STRING] and provide validated methods.
*/
+    private static class ArrayObject {
+        /** Expected arrays, addressed by position; an entry is {@code null} for a null array. */
+        public final List> data;
+        /** Creates a holder with no expected arrays. */
+        public ArrayObject() {
+            this.data = new ArrayList<>();
+        }
+        /** Appends {@code objects} (which may be {@code null}) as the next expected array. */
+        public void add(List objects) {
+            data.add(objects);
+        }
+        /**
+         * Asserts that {@code row} is non-null and that its field 0 matches the {@code i}-th
+         * expected array: the field must be null when the expected array is null, otherwise the
+         * array read from the field must pass {@link #validateNonNullArray}.
+         *
+         * @param row row under test
+         * @param i index into {@link #data}
+         * @param getter reads one element of the array stored in field 0
+         */
+        public void validateRow(InternalRow row, int i, InternalArray.ElementGetter getter) {
+            assertThat(row).isNotNull();
+            List expected = data.get(i);
+            if (expected == null) {
+                assertThat(row.isNullAt(0)).isTrue();
+            } else {
+                validateNonNullArray(expected, row.getArray(0), getter);
+            }
+        }
+        /**
+         * Checks positions {@code 0 .. data.size() - 1} of {@code arrayColumnVector} against
+         * {@link #data}: position {@code i} must be null when the expected array is null,
+         * otherwise {@code getArray(i)} must pass {@link #validateNonNullArray}.
+         */
+        public void validateColumnVector(
+                ArrayColumnVector arrayColumnVector, InternalArray.ElementGetter getter) {
+            for (int i = 0; i < data.size(); i++) {
+                List expected = data.get(i);
+                if (expected == null) {
+                    assertThat(arrayColumnVector.isNullAt(i)).isTrue();
+                } else {
+                    validateNonNullArray(expected, arrayColumnVector.getArray(i), getter);
+                }
+            }
+        }
+        /**
+         * Same check as {@link #validateColumnVector}, but reads the arrays through
+         * {@code DataGetters}. Only positions below {@code data.size()} are inspected; the size
+         * of {@code arrays} itself is not asserted here.
+         */
+        public void validateArrayGetter(DataGetters arrays, InternalArray.ElementGetter getter) {
+            for (int i = 0; i < data.size(); i++) {
+                List expected = data.get(i);
+                if (expected == null) {
+                    assertThat(arrays.isNullAt(i)).isTrue();
+                } else {
+                    validateNonNullArray(expected, arrays.getArray(i), getter);
+                }
+            }
+        }
+        /**
+         * Asserts that {@code array} has the same size as {@code expected} and that every element
+         * has the same {@code String.valueOf} representation as its expected counterpart. Both
+         * sides go through {@code String.valueOf}, so a null element compares equal to "null".
+         */
+        public void validateNonNullArray(
+                List expected, InternalArray array, InternalArray.ElementGetter getter) {
+            int arraySize = array.size();
+            assertThat(arraySize).isEqualTo(expected.size());
+            for (int i = 0; i < arraySize; i++) {
+                String value = String.valueOf(getter.getElementOrNull(array, i));
+                assertThat(value).isEqualTo(String.valueOf(expected.get(i)));
+            }
+        }
+        /**
+         * Asserts that {@code stringGetter.apply(columnVector, i)} equals the
+         * {@code String.valueOf} form of the {@code i}-th element of the flattened expected data.
+         * Only positions below the flattened size are checked.
+         */
+        public void validateInnerChild(
+                ColumnVector columnVector, BiFunction stringGetter) {
+            // Null arrays are skipped: they contribute no elements to the flattened list.
+            List expandedData =
+                    data.stream()
+                            .filter(Objects::nonNull)
+                            .flatMap(Collection::stream)
+                            .collect(Collectors.toList());
+            for (int i = 0; i < expandedData.size(); i++) {
+                assertThat(stringGetter.apply(columnVector, i))
+                        .isEqualTo(String.valueOf(expandedData.get(i)));
+            }
+        }
+    }
+
+    /** Store generated
data of ARRAY[ARRAY[STRING]] and provide validated methods. */
+    private static class ArrayArrayObject {
+        /** Expected outer arrays, addressed by position; an entry is {@code null} for a null array. */
+        public final List data;
+        /** Creates a holder with no expected outer arrays. */
+        public ArrayArrayObject() {
+            this.data = new ArrayList<>();
+        }
+        /** Appends {@code arrayObjects} (or {@code null}) as the next expected outer array. */
+        public void add(@Nullable ArrayObject arrayObjects) {
+            data.add(arrayObjects);
+        }
+        /**
+         * Concatenates, in order, the inner arrays of every non-null outer entry. Null inner
+         * arrays are kept as {@code null} elements of the result.
+         */
+        private List> expand() {
+            // Null outer arrays are skipped: they contribute no inner arrays.
+            return data.stream()
+                    .filter(Objects::nonNull)
+                    .flatMap(i -> i.data.stream())
+                    .collect(Collectors.toList());
+        }
+        /** Flattens {@link #expand()} once more, dropping null inner arrays. */
+        private List expandInner() {
+            // Neither null outer arrays nor null inner arrays contribute elements.
+            return expand().stream()
+                    .filter(Objects::nonNull)
+                    .flatMap(Collection::stream)
+                    .collect(Collectors.toList());
+        }
+        /**
+         * Asserts that {@code row} is non-null and that its field 0 matches the {@code i}-th
+         * expected outer array: the field must be null when the expected entry is null; otherwise
+         * the outer array must have the expected number of inner arrays and each inner array is
+         * checked by {@code ArrayObject.validateArrayGetter}.
+         *
+         * @param row row under test
+         * @param i index into {@link #data}
+         * @param getter reads one element of an inner array
+         */
+        public void validateRow(InternalRow row, int i, InternalArray.ElementGetter getter) {
+            assertThat(row).isNotNull();
+            ArrayObject expectedArray = data.get(i);
+            if (expectedArray == null) {
+                assertThat(row.isNullAt(0)).isTrue();
+            } else {
+                InternalArray outerArray = row.getArray(0);
+                assertThat(outerArray.size()).isEqualTo(expectedArray.data.size());
+                expectedArray.validateArrayGetter(outerArray, getter);
+            }
+        }
+        /**
+         * Checks positions {@code 0 .. data.size() - 1} of the outer {@code arrayColumnVector}:
+         * a null expected entry requires a null position, otherwise the array at that position is
+         * checked by {@code ArrayObject.validateArrayGetter}.
+         *
+         * NOTE(review): unlike {@link #validateRow}, the size of the outer array is not asserted
+         * here, and {@code validateArrayGetter} only visits the expected positions, so surplus
+         * trailing inner arrays would go unnoticed. Consider adding the same size assertion.
+         */
+        public void validateOuterArray(
+                ArrayColumnVector arrayColumnVector,
+                InternalArray.ElementGetter innerElementGetter) {
+            for (int i = 0; i < data.size(); i++) {
+                ArrayObject expected = data.get(i);
+                if (expected == null) {
+                    assertThat(arrayColumnVector.isNullAt(i)).isTrue();
+                } else {
+                    InternalArray array = arrayColumnVector.getArray(i);
+                    expected.validateArrayGetter(array, innerElementGetter);
+                }
+            }
+        }
+        /**
+         * Checks position {@code i} of the inner {@code arrayColumnVector} against the
+         * {@code i}-th entry of {@link #expand()}: a null expected inner array requires a null
+         * position; otherwise sizes must match and elements are compared by their
+         * {@code String.valueOf} representation.
+         */
+        public void validateInnerArray(
+                ArrayColumnVector arrayColumnVector,
+                InternalArray.ElementGetter innerElementGetter) {
+            List> expandedData = expand();
+            for (int i = 0; i < expandedData.size(); i++) {
+                List expected = expandedData.get(i);
+                if (expected == null) {
+                    assertThat(arrayColumnVector.isNullAt(i)).isTrue();
+                } else {
+                    InternalArray array = arrayColumnVector.getArray(i);
+                    int size = array.size();
+                    assertThat(size).isEqualTo(expected.size());
+                    for (int j = 0; j < size; j++) {
+                        assertThat(String.valueOf(innerElementGetter.getElementOrNull(array, j)))
+                                .isEqualTo(String.valueOf(expected.get(j)));
+                    }
+                }
+            }
+        }
+        /**
+         * Asserts that {@code stringGetter.apply(columnVector, i)} equals the
+         * {@code String.valueOf} form of the {@code i}-th element of {@link #expandInner()}.
+         * Only positions below the flattened size are checked.
+         */
+        public void validateInnerChild(
+                ColumnVector columnVector, BiFunction stringGetter) {
+            List expandedData = expandInner();
+            for (int i = 0; i < expandedData.size(); i++) {
+                assertThat(stringGetter.apply(columnVector, i))
+                        .isEqualTo(String.valueOf(expandedData.get(i)));
+            }
+        }
+    }
+}