Skip to content

Commit

Permalink
[arrow] Arrow reader should check schema match every time (apache#4031)
Browse files Browse the repository at this point in the history
  • Loading branch information
leaves12138 authored Aug 22, 2024
1 parent 1d84aa5 commit 88f7c71
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,36 +23,55 @@
import org.apache.paimon.data.columnar.ColumnVector;
import org.apache.paimon.data.columnar.ColumnarRow;
import org.apache.paimon.data.columnar.VectorizedColumnBatch;
import org.apache.paimon.types.DataField;
import org.apache.paimon.types.RowType;

import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;

import java.util.Iterator;
import java.util.List;

/** Reader from a {@link VectorSchemaRoot} to paimon rows. */
public class ArrowBatchReader {

private final VectorizedColumnBatch batch;
private final Arrow2PaimonVectorConverter[] convertors;
private final RowType projectedRowType;

public ArrowBatchReader(RowType rowType) {
ColumnVector[] columnVectors = new ColumnVector[rowType.getFieldCount()];
this.convertors = new Arrow2PaimonVectorConverter[rowType.getFieldCount()];
this.batch = new VectorizedColumnBatch(columnVectors);
this.projectedRowType = rowType;

for (int i = 0; i < columnVectors.length; i++) {
this.convertors[i] = Arrow2PaimonVectorConverter.construct(rowType.getTypeAt(i));
}
}

public Iterable<InternalRow> readBatch(VectorSchemaRoot vsr) {
int[] mapping = new int[projectedRowType.getFieldCount()];
Schema arrowSchema = vsr.getSchema();
List<DataField> dataFields = projectedRowType.getFields();
for (int i = 0; i < dataFields.size(); ++i) {
try {
Field field = arrowSchema.findField(dataFields.get(i).name().toLowerCase());
int idx = arrowSchema.getFields().indexOf(field);
mapping[i] = idx;
} catch (IllegalArgumentException e) {
throw new RuntimeException(e);
}
}

for (int i = 0; i < batch.columns.length; i++) {
batch.columns[i] = convertors[i].convertVector(vsr.getVector(i));
batch.columns[i] = convertors[i].convertVector(vsr.getVector(mapping[i]));
}
int rowCount = vsr.getRowCount();

int rowCount = vsr.getRowCount();
batch.setNumRows(vsr.getRowCount());
ColumnarRow columnarRow = new ColumnarRow(batch);
final ColumnarRow columnarRow = new ColumnarRow(batch);
return () ->
new Iterator<InternalRow>() {
private int position = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.paimon.types.RowType;
import org.apache.paimon.utils.StringUtils;

import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -114,6 +115,48 @@ public void testWrite() {
}
}

/**
 * Writes random rows, rotates the arrow field vectors left by one position so that they no
 * longer line up with {@code PRIMITIVE_TYPE}, and checks that {@link ArrowBatchReader} still
 * returns every row with its original field values.
 */
@Test
public void testReadWithSchemaMessUp() {
    try (ArrowFormatWriter writer = new ArrowFormatWriter(PRIMITIVE_TYPE, 4096)) {
        List<InternalRow> expectedRows = new ArrayList<>();
        List<InternalRow.FieldGetter> fieldGetters = new ArrayList<>();

        // One getter per column of the row type, used below to compare rows field by field.
        for (int pos = 0; pos < PRIMITIVE_TYPE.getFieldCount(); pos++) {
            fieldGetters.add(
                    InternalRow.createFieldGetter(PRIMITIVE_TYPE.getTypeAt(pos), pos));
        }
        for (int n = 0; n < 1000; n++) {
            expectedRows.add(GenericRow.of(randomRowValues(null)));
        }

        for (InternalRow row : expectedRows) {
            writer.write(row);
        }
        writer.flush();

        // Rotate the vectors left by one: the first vector moves to the end and every
        // other vector shifts one slot towards the front.
        List<FieldVector> shuffled = writer.getVectorSchemaRoot().getFieldVectors();
        FieldVector head = shuffled.get(0);
        int lastPos = shuffled.size() - 1;
        for (int pos = 0; pos < lastPos; pos++) {
            shuffled.set(pos, shuffled.get(pos + 1));
        }
        shuffled.set(lastPos, head);

        ArrowBatchReader reader = new ArrowBatchReader(PRIMITIVE_TYPE);
        Iterator<InternalRow> actualRows =
                reader.readBatch(new VectorSchemaRoot(shuffled)).iterator();

        // Rows must come back in write order with identical values in every field.
        for (InternalRow expected : expectedRows) {
            InternalRow actual = actualRows.next();
            for (InternalRow.FieldGetter getter : fieldGetters) {
                Assertions.assertThat(getter.getFieldOrNull(actual))
                        .isEqualTo(getter.getFieldOrNull(expected));
            }
        }
    }
}

private Object[] randomRowValues(boolean[] nullable) {
Object[] values = new Object[18];
values[0] = BinaryString.fromString(StringUtils.getRandomString(RND, 10, 10));
Expand Down

0 comments on commit 88f7c71

Please sign in to comment.