Spark 3.5: Support default values in vectorized reads (apache#11815)
rdblue authored Dec 19, 2024
1 parent 3535240 commit 7033667
Showing 5 changed files with 125 additions and 64 deletions.
Changed file 1 of 5
@@ -20,6 +20,7 @@

import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.IntStream;
import org.apache.arrow.memory.BufferAllocator;
@@ -47,13 +48,30 @@ public class VectorizedReaderBuilder extends TypeWithSchemaVisitor<VectorizedRea
private final Map<Integer, ?> idToConstant;
private final boolean setArrowValidityVector;
private final Function<List<VectorizedReader<?>>, VectorizedReader<?>> readerFactory;
private final BiFunction<org.apache.iceberg.types.Type, Object, Object> convert;

public VectorizedReaderBuilder(
Schema expectedSchema,
MessageType parquetSchema,
boolean setArrowValidityVector,
Map<Integer, ?> idToConstant,
Function<List<VectorizedReader<?>>, VectorizedReader<?>> readerFactory) {
this(
expectedSchema,
parquetSchema,
setArrowValidityVector,
idToConstant,
readerFactory,
(type, value) -> value);
}

protected VectorizedReaderBuilder(
Schema expectedSchema,
MessageType parquetSchema,
boolean setArrowValidityVector,
Map<Integer, ?> idToConstant,
Function<List<VectorizedReader<?>>, VectorizedReader<?>> readerFactory,
BiFunction<org.apache.iceberg.types.Type, Object, Object> convert) {
this.parquetSchema = parquetSchema;
this.icebergSchema = expectedSchema;
this.rootAllocator =
@@ -62,6 +80,7 @@ public VectorizedReaderBuilder(
this.setArrowValidityVector = setArrowValidityVector;
this.idToConstant = idToConstant;
this.readerFactory = readerFactory;
this.convert = convert;
}

@Override
@@ -85,7 +104,7 @@ public VectorizedReader<?> message(
int id = field.fieldId();
VectorizedReader<?> reader = readersById.get(id);
if (idToConstant.containsKey(id)) {
reorderedFields.add(new ConstantVectorReader<>(field, idToConstant.get(id)));
reorderedFields.add(constantReader(field, idToConstant.get(id)));
} else if (id == MetadataColumns.ROW_POSITION.fieldId()) {
if (setArrowValidityVector) {
reorderedFields.add(VectorizedArrowReader.positionsWithSetArrowValidityVector());
@@ -96,13 +115,23 @@ public VectorizedReader<?> message(
reorderedFields.add(new DeletedVectorReader());
} else if (reader != null) {
reorderedFields.add(reader);
} else {
} else if (field.initialDefault() != null) {
reorderedFields.add(
constantReader(field, convert.apply(field.type(), field.initialDefault())));
} else if (field.isOptional()) {
reorderedFields.add(VectorizedArrowReader.nulls());
} else {
throw new IllegalArgumentException(
String.format("Missing required field: %s", field.name()));
}
}
return vectorizedReader(reorderedFields);
}

private <T> ConstantVectorReader<T> constantReader(Types.NestedField field, T constant) {
return new ConstantVectorReader<>(field, constant);
}

protected VectorizedReader<?> vectorizedReader(List<VectorizedReader<?>> reorderedFields) {
return readerFactory.apply(reorderedFields);
}
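Taken together, the two message() hunks above define the per-field fallback order used to build column readers. The following sketch restates that order in one place; it mirrors the diff (readersById, idToConstant, constantReader, and convert are the members shown above) rather than being a verbatim copy of the source:

int id = field.fieldId();
VectorizedReader<?> reader = readersById.get(id);
if (idToConstant.containsKey(id)) {
  // partition value or other metadata constant
  reorderedFields.add(constantReader(field, idToConstant.get(id)));
} else if (id == MetadataColumns.ROW_POSITION.fieldId()) {
  // _pos metadata column (positionsWithSetArrowValidityVector() when validity vectors are enabled)
  reorderedFields.add(VectorizedArrowReader.positions());
} else if (id == MetadataColumns.IS_DELETED.fieldId()) {
  // _deleted metadata column
  reorderedFields.add(new DeletedVectorReader());
} else if (reader != null) {
  // the column is present in the Parquet file
  reorderedFields.add(reader);
} else if (field.initialDefault() != null) {
  // new in this commit: a missing column with an initial default is read as a constant,
  // after converting the default to the engine's internal representation
  reorderedFields.add(constantReader(field, convert.apply(field.type(), field.initialDefault())));
} else if (field.isOptional()) {
  // a missing optional column without a default reads as nulls
  reorderedFields.add(VectorizedArrowReader.nulls());
} else {
  throw new IllegalArgumentException(String.format("Missing required field: %s", field.name()));
}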
Changed file 2 of 5
@@ -27,6 +27,7 @@
import org.apache.iceberg.data.DeleteFilter;
import org.apache.iceberg.parquet.TypeWithSchemaVisitor;
import org.apache.iceberg.parquet.VectorizedReader;
import org.apache.iceberg.spark.SparkUtil;
import org.apache.parquet.schema.MessageType;
import org.apache.spark.sql.catalyst.InternalRow;
import org.slf4j.Logger;
@@ -112,7 +113,13 @@ private static class ReaderBuilder extends VectorizedReaderBuilder {
Map<Integer, ?> idToConstant,
Function<List<VectorizedReader<?>>, VectorizedReader<?>> readerFactory,
DeleteFilter<InternalRow> deleteFilter) {
super(expectedSchema, parquetSchema, setArrowValidityVector, idToConstant, readerFactory);
super(
expectedSchema,
parquetSchema,
setArrowValidityVector,
idToConstant,
readerFactory,
SparkUtil::internalToSpark);
this.deleteFilter = deleteFilter;
}

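The Spark builder plugs in SparkUtil::internalToSpark as the BiFunction<Type, Object, Object> introduced above, so a field's initial default, stored in Iceberg's internal representation, is converted to Spark's internal representation before it is wrapped in a ConstantVectorReader. A rough, hypothetical illustration (the field and the literal default are made up, and the exact result classes depend on SparkUtil.internalToSpark, whose body is not shown in this diff):

// Assumed example: a string default becomes Spark's internal UTF8String.
Types.NestedField field = Types.NestedField.optional(2, "category", Types.StringType.get());
Object icebergDefault = "unknown";                                   // the field's initial default
Object sparkDefault = SparkUtil.internalToSpark(field.type(), icebergDefault);
// sparkDefault should now be a UTF8String; other types (decimals, dates, timestamps)
// are expected to convert to Spark's internal forms (Decimal, int days, long micros) the same way.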
Changed file 3 of 5
@@ -63,6 +63,10 @@ protected boolean supportsDefaultValues() {
return false;
}

protected boolean supportsNestedTypes() {
return true;
}

protected static final StructType SUPPORTED_PRIMITIVES =
StructType.of(
required(100, "id", LongType.get()),
@@ -74,6 +78,7 @@ protected boolean supportsDefaultValues() {
required(106, "d", Types.DoubleType.get()),
optional(107, "date", Types.DateType.get()),
required(108, "ts", Types.TimestampType.withZone()),
required(109, "ts_without_zone", Types.TimestampType.withoutZone()),
required(110, "s", Types.StringType.get()),
required(111, "uuid", Types.UUIDType.get()),
required(112, "fixed", Types.FixedType.ofLength(7)),
@@ -109,12 +114,16 @@ public void testStructWithOptionalFields() throws IOException {

@Test
public void testNestedStruct() throws IOException {
Assumptions.assumeThat(supportsNestedTypes()).isTrue();

writeAndValidate(
TypeUtil.assignIncreasingFreshIds(new Schema(required(1, "struct", SUPPORTED_PRIMITIVES))));
}

@Test
public void testArray() throws IOException {
Assumptions.assumeThat(supportsNestedTypes()).isTrue();

Schema schema =
new Schema(
required(0, "id", LongType.get()),
@@ -125,6 +134,8 @@ public void testArray() throws IOException {

@Test
public void testArrayOfStructs() throws IOException {
Assumptions.assumeThat(supportsNestedTypes()).isTrue();

Schema schema =
TypeUtil.assignIncreasingFreshIds(
new Schema(
@@ -136,6 +147,8 @@ public void testArrayOfStructs() throws IOException {

@Test
public void testMap() throws IOException {
Assumptions.assumeThat(supportsNestedTypes()).isTrue();

Schema schema =
new Schema(
required(0, "id", LongType.get()),
@@ -149,6 +162,8 @@ public void testMap() throws IOException {

@Test
public void testNumericMapKey() throws IOException {
Assumptions.assumeThat(supportsNestedTypes()).isTrue();

Schema schema =
new Schema(
required(0, "id", LongType.get()),
@@ -160,6 +175,8 @@ public void testNumericMapKey() throws IOException {

@Test
public void testComplexMapKey() throws IOException {
Assumptions.assumeThat(supportsNestedTypes()).isTrue();

Schema schema =
new Schema(
required(0, "id", LongType.get()),
@@ -179,6 +196,8 @@ public void testComplexMapKey() throws IOException {

@Test
public void testMapOfStructs() throws IOException {
Assumptions.assumeThat(supportsNestedTypes()).isTrue();

Schema schema =
TypeUtil.assignIncreasingFreshIds(
new Schema(
@@ -193,6 +212,8 @@ public void testMapOfStructs() throws IOException {

@Test
public void testMixedTypes() throws IOException {
Assumptions.assumeThat(supportsNestedTypes()).isTrue();

StructType structType =
StructType.of(
required(0, "id", LongType.get()),
@@ -248,17 +269,6 @@ public void testMixedTypes() throws IOException {
writeAndValidate(schema);
}

@Test
public void testTimestampWithoutZone() throws IOException {
Schema schema =
TypeUtil.assignIncreasingFreshIds(
new Schema(
required(0, "id", LongType.get()),
optional(1, "ts_without_zone", Types.TimestampType.withoutZone())));

writeAndValidate(schema);
}

@Test
public void testMissingRequiredWithoutDefault() {
Assumptions.assumeThat(supportsDefaultValues()).isTrue();
@@ -348,6 +358,7 @@ public void testNullDefaultValue() throws IOException {
@Test
public void testNestedDefaultValue() throws IOException {
Assumptions.assumeThat(supportsDefaultValues()).isTrue();
Assumptions.assumeThat(supportsNestedTypes()).isTrue();

Schema writeSchema =
new Schema(
@@ -391,6 +402,7 @@ public void testNestedDefaultValue() throws IOException {
@Test
public void testMapNestedDefaultValue() throws IOException {
Assumptions.assumeThat(supportsDefaultValues()).isTrue();
Assumptions.assumeThat(supportsNestedTypes()).isTrue();

Schema writeSchema =
new Schema(
@@ -443,6 +455,7 @@ public void testMapNestedDefaultValue() throws IOException {
@Test
public void testListNestedDefaultValue() throws IOException {
Assumptions.assumeThat(supportsDefaultValues()).isTrue();
Assumptions.assumeThat(supportsNestedTypes()).isTrue();

Schema writeSchema =
new Schema(
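The new supportsNestedTypes() hook defaults to true, and the nested-type tests above are now guarded with Assumptions, so a format whose vectorized path cannot read nested columns can opt out while still running the flat-schema and default-value tests. A hypothetical subclass would override the hooks roughly like this (not part of this commit's shown diff):

@Override
protected boolean supportsNestedTypes() {
  return false;   // skips testNestedStruct, testArray, testMap, testMapOfStructs, ... via Assumptions
}

@Override
protected boolean supportsDefaultValues() {
  return true;    // the default-value tests still run
}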
Changed file 4 of 5
@@ -79,6 +79,8 @@
import org.apache.spark.sql.types.MapType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.TimestampNTZType;
import org.apache.spark.sql.types.TimestampType$;
import org.apache.spark.sql.vectorized.ColumnarBatch;
import org.apache.spark.unsafe.types.UTF8String;
import scala.collection.Seq;
@@ -107,13 +109,25 @@ public static void assertEqualsSafe(Types.StructType struct, Record rec, Row row
public static void assertEqualsBatch(
Types.StructType struct, Iterator<Record> expected, ColumnarBatch batch) {
for (int rowId = 0; rowId < batch.numRows(); rowId++) {
List<Types.NestedField> fields = struct.fields();
InternalRow row = batch.getRow(rowId);
Record rec = expected.next();
for (int i = 0; i < fields.size(); i += 1) {
Type fieldType = fields.get(i).type();
Object expectedValue = rec.get(i);
Object actualValue = row.isNullAt(i) ? null : row.get(i, convert(fieldType));

List<Types.NestedField> fields = struct.fields();
for (int readPos = 0; readPos < fields.size(); readPos += 1) {
Types.NestedField field = fields.get(readPos);
Field writeField = rec.getSchema().getField(field.name());

Type fieldType = field.type();
Object actualValue = row.isNullAt(readPos) ? null : row.get(readPos, convert(fieldType));

Object expectedValue;
if (writeField != null) {
int writePos = writeField.pos();
expectedValue = rec.get(writePos);
} else {
expectedValue = field.initialDefault();
}

assertEqualsUnsafe(fieldType, expectedValue, actualValue);
}
}
@@ -751,6 +765,12 @@ private static void assertEquals(
for (int i = 0; i < actual.numFields(); i += 1) {
StructField field = struct.fields()[i];
DataType type = field.dataType();
// ColumnarRow.get doesn't support TimestampNTZType, causing tests to fail. the representation
// is identical to TimestampType so this uses that type to validate.
if (type instanceof TimestampNTZType) {
type = TimestampType$.MODULE$;
}

assertEquals(
context + "." + field.name(),
type,
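The substitution above works because Spark stores TimestampType and TimestampNTZType values identically in internal rows, as microseconds since the epoch in a long; only the logical type differs. A small hedged illustration (standalone, not from this commit; requires org.apache.spark.sql.catalyst.expressions.GenericInternalRow and org.apache.spark.sql.types.DataTypes):

long micros = 1_700_000_000_000_000L;                                  // arbitrary timestamp value
InternalRow row = new GenericInternalRow(new Object[] {micros});
Object value = row.get(0, org.apache.spark.sql.types.DataTypes.TimestampType);
// value is the same boxed long, so asserting through TimestampType checks the same bytes a
// TimestampNTZType column would hold; ColumnarRow.get does not support TimestampNTZType,
// which is why the assertion swaps in TimestampType.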
Changed file 5 of 5 (diff not rendered on this page)
