From 88f7c71497c7d14a424273ff1ab39f66d4c1620f Mon Sep 17 00:00:00 2001 From: YeJunHao <41894543+leaves12138@users.noreply.github.com> Date: Thu, 22 Aug 2024 14:32:19 +0800 Subject: [PATCH] [arrow] Arrow reader should check schema match every time (#4031) --- .../paimon/arrow/reader/ArrowBatchReader.java | 25 +++++++++-- .../arrow/vector/ArrowFormatWriterTest.java | 43 +++++++++++++++++++ 2 files changed, 65 insertions(+), 3 deletions(-) diff --git a/paimon-arrow/src/main/java/org/apache/paimon/arrow/reader/ArrowBatchReader.java b/paimon-arrow/src/main/java/org/apache/paimon/arrow/reader/ArrowBatchReader.java index ec76cd95d003..b8d32e57d442 100644 --- a/paimon-arrow/src/main/java/org/apache/paimon/arrow/reader/ArrowBatchReader.java +++ b/paimon-arrow/src/main/java/org/apache/paimon/arrow/reader/ArrowBatchReader.java @@ -23,22 +23,28 @@ import org.apache.paimon.data.columnar.ColumnVector; import org.apache.paimon.data.columnar.ColumnarRow; import org.apache.paimon.data.columnar.VectorizedColumnBatch; +import org.apache.paimon.types.DataField; import org.apache.paimon.types.RowType; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; import java.util.Iterator; +import java.util.List; /** Reader from a {@link VectorSchemaRoot} to paimon rows. */ public class ArrowBatchReader { private final VectorizedColumnBatch batch; private final Arrow2PaimonVectorConverter[] convertors; + private final RowType projectedRowType; public ArrowBatchReader(RowType rowType) { ColumnVector[] columnVectors = new ColumnVector[rowType.getFieldCount()]; this.convertors = new Arrow2PaimonVectorConverter[rowType.getFieldCount()]; this.batch = new VectorizedColumnBatch(columnVectors); + this.projectedRowType = rowType; for (int i = 0; i < columnVectors.length; i++) { this.convertors[i] = Arrow2PaimonVectorConverter.construct(rowType.getTypeAt(i)); @@ -46,13 +52,26 @@ public ArrowBatchReader(RowType rowType) { } public Iterable readBatch(VectorSchemaRoot vsr) { + int[] mapping = new int[projectedRowType.getFieldCount()]; + Schema arrowSchema = vsr.getSchema(); + List dataFields = projectedRowType.getFields(); + for (int i = 0; i < dataFields.size(); ++i) { + try { + Field field = arrowSchema.findField(dataFields.get(i).name().toLowerCase()); + int idx = arrowSchema.getFields().indexOf(field); + mapping[i] = idx; + } catch (IllegalArgumentException e) { + throw new RuntimeException(e); + } + } + for (int i = 0; i < batch.columns.length; i++) { - batch.columns[i] = convertors[i].convertVector(vsr.getVector(i)); + batch.columns[i] = convertors[i].convertVector(vsr.getVector(mapping[i])); } - int rowCount = vsr.getRowCount(); + int rowCount = vsr.getRowCount(); batch.setNumRows(vsr.getRowCount()); - ColumnarRow columnarRow = new ColumnarRow(batch); + final ColumnarRow columnarRow = new ColumnarRow(batch); return () -> new Iterator() { private int position = 0; diff --git a/paimon-arrow/src/test/java/org/apache/paimon/arrow/vector/ArrowFormatWriterTest.java b/paimon-arrow/src/test/java/org/apache/paimon/arrow/vector/ArrowFormatWriterTest.java index 8d7b5b7e6fd5..dc3b197a33e0 100644 --- a/paimon-arrow/src/test/java/org/apache/paimon/arrow/vector/ArrowFormatWriterTest.java +++ b/paimon-arrow/src/test/java/org/apache/paimon/arrow/vector/ArrowFormatWriterTest.java @@ -29,6 +29,7 @@ import org.apache.paimon.types.RowType; import org.apache.paimon.utils.StringUtils; +import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.assertj.core.api.Assertions; import org.junit.jupiter.api.Test; @@ -114,6 +115,48 @@ public void testWrite() { } } + @Test + public void testReadWithSchemaMessUp() { + try (ArrowFormatWriter writer = new ArrowFormatWriter(PRIMITIVE_TYPE, 4096)) { + List list = new ArrayList<>(); + List fieldGetters = new ArrayList<>(); + + for (int i = 0; i < PRIMITIVE_TYPE.getFieldCount(); i++) { + fieldGetters.add(InternalRow.createFieldGetter(PRIMITIVE_TYPE.getTypeAt(i), i)); + } + for (int i = 0; i < 1000; i++) { + list.add(GenericRow.of(randomRowValues(null))); + } + + list.forEach(writer::write); + + writer.flush(); + VectorSchemaRoot vectorSchemaRoot = writer.getVectorSchemaRoot(); + + // mess up the fields + List vectors = vectorSchemaRoot.getFieldVectors(); + FieldVector vector0 = vectors.get(0); + for (int i = 0; i < vectors.size() - 1; i++) { + vectors.set(i, vectors.get(i + 1)); + } + vectors.set(vectors.size() - 1, vector0); + + ArrowBatchReader arrowBatchReader = new ArrowBatchReader(PRIMITIVE_TYPE); + Iterable rows = arrowBatchReader.readBatch(new VectorSchemaRoot(vectors)); + + Iterator iterator = rows.iterator(); + for (int i = 0; i < 1000; i++) { + InternalRow actual = iterator.next(); + InternalRow expectec = list.get(i); + + for (InternalRow.FieldGetter fieldGetter : fieldGetters) { + Assertions.assertThat(fieldGetter.getFieldOrNull(actual)) + .isEqualTo(fieldGetter.getFieldOrNull(expectec)); + } + } + } + } + private Object[] randomRowValues(boolean[] nullable) { Object[] values = new Object[18]; values[0] = BinaryString.fromString(StringUtils.getRandomString(RND, 10, 10));