
Commit d0a7ff9

Arrow: Fix indexing in Parquet dictionary encoded values readers (apa…
wypoon authored Oct 21, 2024
1 parent c16cefa commit d0a7ff9
Showing 4 changed files with 116 additions and 13 deletions.
@@ -58,14 +58,10 @@ public void nextBatch(
       }
       int numValues = Math.min(left, currentCount);
       for (int i = 0; i < numValues; i++) {
-        int index = idx * typeWidth;
-        if (typeWidth == -1) {
-          index = idx;
-        }
         if (Mode.RLE.equals(mode)) {
-          nextVal(vector, dict, index, currentValue, typeWidth);
+          nextVal(vector, dict, idx, currentValue, typeWidth);
         } else if (Mode.PACKED.equals(mode)) {
-          nextVal(vector, dict, index, packedValuesBuffer[packedValuesBufferIdx++], typeWidth);
+          nextVal(vector, dict, idx, packedValuesBuffer[packedValuesBufferIdx++], typeWidth);
         }
         nullabilityHolder.setNotNull(idx);
         if (setArrowValidityVector) {
@@ -94,15 +90,15 @@ class LongDictEncodedReader extends BaseDictEncodedReader {
     @Override
     protected void nextVal(
         FieldVector vector, Dictionary dict, int idx, int currentVal, int typeWidth) {
-      vector.getDataBuffer().setLong(idx, dict.decodeToLong(currentVal));
+      vector.getDataBuffer().setLong((long) idx * typeWidth, dict.decodeToLong(currentVal));
     }
   }

   class TimestampMillisDictEncodedReader extends BaseDictEncodedReader {
     @Override
     protected void nextVal(
         FieldVector vector, Dictionary dict, int idx, int currentVal, int typeWidth) {
-      vector.getDataBuffer().setLong(idx, dict.decodeToLong(currentVal) * 1000);
+      vector.getDataBuffer().setLong((long) idx * typeWidth, dict.decodeToLong(currentVal) * 1000);
     }
   }

@@ -113,31 +109,31 @@ protected void nextVal(
       ByteBuffer buffer =
           dict.decodeToBinary(currentVal).toByteBuffer().order(ByteOrder.LITTLE_ENDIAN);
       long timestampInt96 = ParquetUtil.extractTimestampInt96(buffer);
-      vector.getDataBuffer().setLong(idx, timestampInt96);
+      vector.getDataBuffer().setLong((long) idx * typeWidth, timestampInt96);
     }
   }

   class IntegerDictEncodedReader extends BaseDictEncodedReader {
     @Override
     protected void nextVal(
         FieldVector vector, Dictionary dict, int idx, int currentVal, int typeWidth) {
-      vector.getDataBuffer().setInt(idx, dict.decodeToInt(currentVal));
+      vector.getDataBuffer().setInt((long) idx * typeWidth, dict.decodeToInt(currentVal));
     }
   }

   class FloatDictEncodedReader extends BaseDictEncodedReader {
     @Override
     protected void nextVal(
         FieldVector vector, Dictionary dict, int idx, int currentVal, int typeWidth) {
-      vector.getDataBuffer().setFloat(idx, dict.decodeToFloat(currentVal));
+      vector.getDataBuffer().setFloat((long) idx * typeWidth, dict.decodeToFloat(currentVal));
     }
   }

   class DoubleDictEncodedReader extends BaseDictEncodedReader {
     @Override
     protected void nextVal(
         FieldVector vector, Dictionary dict, int idx, int currentVal, int typeWidth) {
-      vector.getDataBuffer().setDouble(idx, dict.decodeToDouble(currentVal));
+      vector.getDataBuffer().setDouble((long) idx * typeWidth, dict.decodeToDouble(currentVal));
     }
   }

@@ -150,7 +146,7 @@ class FixedWidthBinaryDictEncodedReader extends BaseDictEncodedReader {
     protected void nextVal(
         FieldVector vector, Dictionary dict, int idx, int currentVal, int typeWidth) {
       ByteBuffer buffer = dict.decodeToBinary(currentVal).toByteBuffer();
-      vector.getDataBuffer().setBytes(idx, buffer);
+      vector.getDataBuffer().setBytes((long) idx * typeWidth, buffer);
     }
   }

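For context on the change above: Arrow's ArrowBuf setters (setLong, setInt, setFloat, setDouble, setBytes) address the data buffer by byte offset rather than by row index, so after this fix each fixed-width dictionary reader computes the offset as (long) idx * typeWidth at its own call site instead of receiving a pre-scaled index from nextBatch. Readers that need the plain row index (not shown in this hunk) can now use idx as-is, which appears to be the motivation for moving the multiplication into the individual readers. The sketch below is illustrative only (not Iceberg code; the class name and values are made up) and shows the same byte-offset arithmetic against a plain Arrow BigIntVector, whose values are 8 bytes wide.

import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BigIntVector;

// Illustrative sketch: write row values into a BigIntVector's data buffer the same
// way the fixed-width dictionary readers do, i.e. at byte offset idx * typeWidth.
public class ArrowByteOffsetSketch {
  public static void main(String[] args) {
    int typeWidth = 8; // a BIGINT value occupies 8 bytes in the Arrow data buffer
    try (RootAllocator allocator = new RootAllocator();
        BigIntVector vector = new BigIntVector("col", allocator)) {
      vector.allocateNew(4);
      for (int idx = 0; idx < 4; idx++) {
        // ArrowBuf.setLong takes a byte offset, so row idx lives at idx * typeWidth
        vector.getDataBuffer().setLong((long) idx * typeWidth, idx * 100L);
      }
      vector.setValueCount(4);
      // Read back through the same arithmetic: prints 200
      System.out.println(vector.getDataBuffer().getLong(2L * typeWidth));
    }
  }
}
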
@@ -116,6 +116,21 @@ public static void assertEqualsBatch(
     }
   }

+  public static void assertEqualsBatchWithRows(
+      Types.StructType struct, Iterator<Row> expected, ColumnarBatch batch) {
+    for (int rowId = 0; rowId < batch.numRows(); rowId++) {
+      List<Types.NestedField> fields = struct.fields();
+      InternalRow row = batch.getRow(rowId);
+      Row expectedRow = expected.next();
+      for (int i = 0; i < fields.size(); i += 1) {
+        Type fieldType = fields.get(i).type();
+        Object expectedValue = expectedRow.get(i);
+        Object actualValue = row.isNullAt(i) ? null : row.get(i, convert(fieldType));
+        assertEqualsUnsafe(fieldType, expectedValue, actualValue);
+      }
+    }
+  }
+
   private static void assertEqualsSafe(Types.ListType list, Collection<?> expected, List actual) {
     Type elementType = list.elementType();
     List<?> expectedElements = Lists.newArrayList(expected);
@@ -18,13 +18,21 @@
  */
 package org.apache.iceberg.spark.data.parquet.vectorized;

+import static org.apache.iceberg.TableProperties.PARQUET_DICT_SIZE_BYTES;
+import static org.apache.iceberg.TableProperties.PARQUET_PAGE_ROW_LIMIT;
 import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT;
 import static org.assertj.core.api.Assertions.assertThat;

 import java.io.File;
 import java.io.IOException;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.Iterator;
+import java.util.List;
 import org.apache.avro.generic.GenericData;
 import org.apache.iceberg.Files;
 import org.apache.iceberg.Schema;
+import org.apache.iceberg.io.CloseableIterable;
 import org.apache.iceberg.io.FileAppender;
 import org.apache.iceberg.parquet.Parquet;
 import org.apache.iceberg.relocated.com.google.common.base.Function;
@@ -33,11 +41,35 @@
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
 import org.apache.iceberg.spark.data.RandomData;
+import org.apache.iceberg.spark.data.TestHelpers;
+import org.apache.iceberg.spark.data.vectorized.VectorizedSparkParquetReaders;
+import org.apache.iceberg.types.Types;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.vectorized.ColumnarBatch;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Test;

 public class TestParquetDictionaryEncodedVectorizedReads extends TestParquetVectorizedReads {

+  protected static SparkSession spark = null;
+
+  @BeforeAll
+  public static void startSpark() {
+    spark = SparkSession.builder().master("local[2]").getOrCreate();
+  }
+
+  @AfterAll
+  public static void stopSpark() {
+    if (spark != null) {
+      spark.stop();
+      spark = null;
+    }
+  }
+
   @Override
   Iterable<GenericData.Record> generateData(
       Schema schema,
@@ -93,4 +125,64 @@ public void testMixedDictionaryNonDictionaryReads() throws IOException {
         true,
         BATCH_SIZE);
   }
+
+  @Test
+  public void testBinaryNotAllPagesDictionaryEncoded() throws IOException {
+    Schema schema = new Schema(Types.NestedField.required(1, "bytes", Types.BinaryType.get()));
+    File parquetFile = File.createTempFile("junit", null, temp.toFile());
+    assertThat(parquetFile.delete()).as("Delete should succeed").isTrue();
+
+    Iterable<GenericData.Record> records = RandomData.generateFallbackData(schema, 500, 0L, 100);
+    try (FileAppender<GenericData.Record> writer =
+        Parquet.write(Files.localOutput(parquetFile))
+            .schema(schema)
+            .set(PARQUET_DICT_SIZE_BYTES, "4096")
+            .set(PARQUET_PAGE_ROW_LIMIT, "100")
+            .build()) {
+      writer.addAll(records);
+    }
+
+    // After the above, parquetFile contains one column chunk of binary data in five pages,
+    // the first two RLE dictionary encoded, and the remaining three plain encoded.
+    assertRecordsMatch(schema, 500, records, parquetFile, true, BATCH_SIZE);
+  }
+
+  /**
+   * decimal_dict_and_plain_encoding.parquet contains one column chunk of decimal(38, 0) data in
+   * two pages, one RLE dictionary encoded and one plain encoded, each with 200 rows.
+   */
+  @Test
+  public void testDecimalNotAllPagesDictionaryEncoded() throws Exception {
+    Schema schema = new Schema(Types.NestedField.required(1, "id", Types.DecimalType.of(38, 0)));
+    Path path =
+        Paths.get(
+            getClass()
+                .getClassLoader()
+                .getResource("decimal_dict_and_plain_encoding.parquet")
+                .toURI());
+
+    Dataset<Row> df = spark.read().parquet(path.toString());
+    List<Row> expected = df.collectAsList();
+    long expectedSize = df.count();
+
+    Parquet.ReadBuilder readBuilder =
+        Parquet.read(Files.localInput(path.toFile()))
+            .project(schema)
+            .createBatchedReaderFunc(
+                type ->
+                    VectorizedSparkParquetReaders.buildReader(
+                        schema, type, ImmutableMap.of(), null));
+
+    try (CloseableIterable<ColumnarBatch> batchReader = readBuilder.build()) {
+      Iterator<Row> expectedIter = expected.iterator();
+      Iterator<ColumnarBatch> batches = batchReader.iterator();
+      int numRowsRead = 0;
+      while (batches.hasNext()) {
+        ColumnarBatch batch = batches.next();
+        numRowsRead += batch.numRows();
+        TestHelpers.assertEqualsBatchWithRows(schema.asStruct(), expectedIter, batch);
+      }
+      assertThat(numRowsRead).isEqualTo(expectedSize);
+    }
+  }
 }
Binary file not shown.
