Skip to content

Commit

Permalink
[spark] Support nested col pruning (apache#4269)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zouxxyy authored Oct 9, 2024
1 parent 05a0fab commit 20a3967
Show file tree
Hide file tree
Showing 12 changed files with 336 additions and 52 deletions.
19 changes: 19 additions & 0 deletions paimon-common/src/main/java/org/apache/paimon/types/ArrayType.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ public DataType getElementType() {
return elementType;
}

public DataType newElementType(DataType newElementType) {
return new ArrayType(isNullable(), newElementType);
}

@Override
public int defaultSize() {
return elementType.defaultSize();
Expand Down Expand Up @@ -96,6 +100,21 @@ public boolean equals(Object o) {
return elementType.equals(arrayType.elementType);
}

@Override
public boolean isPrunedFrom(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
if (!super.equals(o)) {
return false;
}
ArrayType arrayType = (ArrayType) o;
return elementType.isPrunedFrom(arrayType.elementType);
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), elementType);
Expand Down
19 changes: 19 additions & 0 deletions paimon-common/src/main/java/org/apache/paimon/types/MapType.java
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ public DataType copy(boolean isNullable) {
return new MapType(isNullable, keyType.copy(), valueType.copy());
}

public DataType newKeyValueType(DataType newKeyType, DataType newValueType) {
return new MapType(isNullable(), newKeyType, newValueType);
}

@Override
public String asSQLString() {
return withNullability(FORMAT, keyType.asSQLString(), valueType.asSQLString());
Expand Down Expand Up @@ -105,6 +109,21 @@ public boolean equals(Object o) {
return keyType.equals(mapType.keyType) && valueType.equals(mapType.valueType);
}

@Override
public boolean isPrunedFrom(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
if (!super.equals(o)) {
return false;
}
MapType mapType = (MapType) o;
return keyType.isPrunedFrom(mapType.keyType) && valueType.isPrunedFrom(mapType.valueType);
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), keyType, valueType);
Expand Down
18 changes: 18 additions & 0 deletions paimon-common/src/main/java/org/apache/paimon/types/RowType.java
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,15 @@ public boolean containsField(String fieldName) {
return false;
}

public boolean containsField(int fieldId) {
for (DataField field : fields) {
if (field.id() == fieldId) {
return true;
}
}
return false;
}

public boolean notContainsField(String fieldName) {
return !containsField(fieldName);
}
Expand All @@ -136,6 +145,15 @@ public DataField getField(String fieldName) {
throw new RuntimeException("Cannot find field: " + fieldName);
}

public DataField getField(int fieldId) {
for (DataField field : fields) {
if (field.id() == fieldId) {
return field;
}
}
throw new RuntimeException("Cannot find field by field id: " + fieldId);
}

public int getFieldIndexByFieldId(int fieldId) {
for (int i = 0; i < fields.size(); i++) {
if (fields.get(i).id() == fieldId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@
import org.apache.paimon.schema.IndexCastMapping;
import org.apache.paimon.schema.SchemaEvolutionUtil;
import org.apache.paimon.schema.TableSchema;
import org.apache.paimon.types.ArrayType;
import org.apache.paimon.types.DataField;
import org.apache.paimon.types.DataType;
import org.apache.paimon.types.MapType;
import org.apache.paimon.types.RowType;

import javax.annotation.Nullable;
Expand All @@ -37,7 +40,6 @@
import java.util.function.Function;

import static org.apache.paimon.predicate.PredicateBuilder.excludePredicateWithFields;
import static org.apache.paimon.utils.Preconditions.checkState;

/** Class with index mapping and bulk format. */
public class BulkFormatMapping {
Expand Down Expand Up @@ -152,27 +154,48 @@ private List<DataField> readDataFields(TableSchema dataSchema) {
.filter(f -> f.id() == dataField.id())
.findFirst()
.ifPresent(
f -> {
if (f.type() instanceof RowType) {
RowType tableFieldType = (RowType) f.type();
RowType dataFieldType = (RowType) dataField.type();
checkState(tableFieldType.isPrunedFrom(dataFieldType));
// Since the nested type schema evolution is not supported,
// directly copy the fields from tableField's type to
// dataField's type.
// todo: support nested type schema evolutions.
field ->
readDataFields.add(
dataField.newType(
dataFieldType.copy(
tableFieldType.getFields())));
} else {
readDataFields.add(dataField);
}
});
pruneDataType(
field.type(), dataField.type()))));
}
return readDataFields;
}

private DataType pruneDataType(DataType readType, DataType dataType) {
switch (readType.getTypeRoot()) {
case ROW:
RowType r = (RowType) readType;
RowType d = (RowType) dataType;
ArrayList<DataField> newFields = new ArrayList<>();
for (DataField rf : r.getFields()) {
if (d.containsField(rf.id())) {
DataField df = d.getField(rf.id());
newFields.add(df.newType(pruneDataType(rf.type(), df.type())));
}
}
return d.copy(newFields);
case MAP:
return ((MapType) dataType)
.newKeyValueType(
pruneDataType(
((MapType) readType).getKeyType(),
((MapType) dataType).getKeyType()),
pruneDataType(
((MapType) readType).getValueType(),
((MapType) dataType).getValueType()));
case ARRAY:
return ((ArrayType) dataType)
.newElementType(
pruneDataType(
((ArrayType) readType).getElementType(),
((ArrayType) dataType).getElementType()));
default:
return dataType;
}
}

private List<Predicate> readFilters(
List<Predicate> filters, TableSchema tableSchema, TableSchema dataSchema) {
List<Predicate> dataFilters =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@
import org.apache.paimon.options.Options;
import org.apache.paimon.reader.RecordReader;
import org.apache.paimon.reader.RecordReader.RecordIterator;
import org.apache.paimon.types.ArrayType;
import org.apache.paimon.types.DataField;
import org.apache.paimon.types.DataType;
import org.apache.paimon.types.MapType;
import org.apache.paimon.types.RowType;
import org.apache.paimon.utils.Pool;

Expand All @@ -45,6 +48,7 @@
import org.apache.parquet.hadoop.ParquetInputFormat;
import org.apache.parquet.io.ColumnIOFactory;
import org.apache.parquet.io.MessageColumnIO;
import org.apache.parquet.schema.ConversionPatterns;
import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.Type;
Expand All @@ -55,11 +59,17 @@
import javax.annotation.Nullable;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import static org.apache.paimon.format.parquet.ParquetSchemaConverter.LIST_ELEMENT_NAME;
import static org.apache.paimon.format.parquet.ParquetSchemaConverter.LIST_NAME;
import static org.apache.paimon.format.parquet.ParquetSchemaConverter.MAP_KEY_NAME;
import static org.apache.paimon.format.parquet.ParquetSchemaConverter.MAP_REPEATED_NAME;
import static org.apache.paimon.format.parquet.ParquetSchemaConverter.MAP_VALUE_NAME;
import static org.apache.paimon.format.parquet.reader.ParquetSplitReaderUtil.buildFieldsList;
import static org.apache.paimon.format.parquet.reader.ParquetSplitReaderUtil.createColumnReader;
import static org.apache.paimon.format.parquet.reader.ParquetSplitReaderUtil.createWritableColumnVector;
Expand Down Expand Up @@ -155,13 +165,59 @@ private MessageType clipParquetSchema(GroupType parquetSchema) {
ParquetSchemaConverter.convertToParquetType(fieldName, projectedTypes[i]);
unknownFieldsIndices.add(i);
} else {
types[i] = parquetSchema.getType(fieldName);
Type parquetType = parquetSchema.getType(fieldName);
types[i] = clipParquetType(projectedTypes[i], parquetType);
}
}

return Types.buildMessage().addFields(types).named("paimon-parquet");
}

/** Clips `parquetType` by `readType`. */
private Type clipParquetType(DataType readType, Type parquetType) {
switch (readType.getTypeRoot()) {
case ROW:
RowType rowType = (RowType) readType;
GroupType rowGroup = (GroupType) parquetType;
List<Type> rowGroupFields = new ArrayList<>();
for (DataField field : rowType.getFields()) {
String fieldName = field.name();
if (rowGroup.containsField(fieldName)) {
Type type = rowGroup.getType(fieldName);
rowGroupFields.add(clipParquetType(field.type(), type));
} else {
// todo: support nested field missing
throw new RuntimeException("field " + fieldName + " is missing");
}
}
return rowGroup.withNewFields(rowGroupFields);
case MAP:
MapType mapType = (MapType) readType;
GroupType mapGroup = (GroupType) parquetType;
GroupType keyValue = mapGroup.getType(MAP_REPEATED_NAME).asGroupType();
return ConversionPatterns.mapType(
mapGroup.getRepetition(),
mapGroup.getName(),
MAP_REPEATED_NAME,
clipParquetType(mapType.getKeyType(), keyValue.getType(MAP_KEY_NAME)),
keyValue.containsField(MAP_VALUE_NAME)
? clipParquetType(
mapType.getValueType(), keyValue.getType(MAP_VALUE_NAME))
: null);
case ARRAY:
ArrayType arrayType = (ArrayType) readType;
GroupType arrayGroup = (GroupType) parquetType;
GroupType list = arrayGroup.getType(LIST_NAME).asGroupType();
return ConversionPatterns.listOfElements(
arrayGroup.getRepetition(),
arrayGroup.getName(),
clipParquetType(
arrayType.getElementType(), list.getType(LIST_ELEMENT_NAME)));
default:
return parquetType;
}
}

private void checkSchema(MessageType fileSchema, MessageType requestedSchema)
throws IOException, UnsupportedOperationException {
if (projectedFields.length != requestedSchema.getFieldCount()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@
public class ParquetSchemaConverter {

static final String MAP_REPEATED_NAME = "key_value";
static final String MAP_KEY_NAME = "key";
static final String MAP_VALUE_NAME = "value";
static final String LIST_NAME = "list";
static final String LIST_ELEMENT_NAME = "element";

public static MessageType convertToParquetMessageType(String name, RowType rowType) {
Expand Down Expand Up @@ -149,8 +152,8 @@ private static Type convertToParquetType(
repetition,
name,
MAP_REPEATED_NAME,
convertToParquetType("key", keyType),
convertToParquetType("value", mapType.getValueType()));
convertToParquetType(MAP_KEY_NAME, keyType),
convertToParquetType(MAP_VALUE_NAME, mapType.getValueType()));
case MULTISET:
MultisetType multisetType = (MultisetType) type;
DataType elementType = multisetType.getElementType();
Expand All @@ -163,8 +166,8 @@ private static Type convertToParquetType(
repetition,
name,
MAP_REPEATED_NAME,
convertToParquetType("key", elementType),
convertToParquetType("value", new IntType(false)));
convertToParquetType(MAP_KEY_NAME, elementType),
convertToParquetType(MAP_VALUE_NAME, new IntType(false)));
case ROW:
RowType rowType = (RowType) type;
return new GroupType(repetition, name, convertToParquetTypes(rowType));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -370,12 +370,12 @@ private static List<ColumnDescriptor> getAllColumnDescriptorByType(
}

public static List<ParquetField> buildFieldsList(
List<DataField> childrens, List<String> fieldNames, MessageColumnIO columnIO) {
List<DataField> children, List<String> fieldNames, MessageColumnIO columnIO) {
List<ParquetField> list = new ArrayList<>();
for (int i = 0; i < childrens.size(); i++) {
for (int i = 0; i < children.size(); i++) {
list.add(
constructField(
childrens.get(i), lookupColumnByName(columnIO, fieldNames.get(i))));
children.get(i), lookupColumnByName(columnIO, fieldNames.get(i))));
}
return list;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,42 @@ public static org.apache.paimon.types.DataType toPaimonType(DataType dataType) {
return SparkToPaimonTypeVisitor.visit(dataType);
}

/**
* Prune Paimon `RowType` by required Spark `StructType`, use this method instead of {@link
* #toPaimonType(DataType)} when need to retain the field id.
*/
public static RowType prunePaimonRowType(StructType requiredStructType, RowType rowType) {
return (RowType) prunePaimonType(requiredStructType, rowType);
}

private static org.apache.paimon.types.DataType prunePaimonType(
DataType sparkDataType, org.apache.paimon.types.DataType paimonDataType) {
if (sparkDataType instanceof StructType) {
StructType s = (StructType) sparkDataType;
RowType p = (RowType) paimonDataType;
List<DataField> newFields = new ArrayList<>();
for (StructField field : s.fields()) {
DataField f = p.getField(field.name());
newFields.add(f.newType(prunePaimonType(field.dataType(), f.type())));
}
return p.copy(newFields);
} else if (sparkDataType instanceof org.apache.spark.sql.types.MapType) {
org.apache.spark.sql.types.MapType s =
(org.apache.spark.sql.types.MapType) sparkDataType;
MapType p = (MapType) paimonDataType;
return p.newKeyValueType(
prunePaimonType(s.keyType(), p.getKeyType()),
prunePaimonType(s.valueType(), p.getValueType()));
} else if (sparkDataType instanceof org.apache.spark.sql.types.ArrayType) {
org.apache.spark.sql.types.ArrayType s =
(org.apache.spark.sql.types.ArrayType) sparkDataType;
ArrayType r = (ArrayType) paimonDataType;
return r.newElementType(prunePaimonType(s.elementType(), r.getElementType()));
} else {
return paimonDataType;
}
}

private static class PaimonToSparkTypeVisitor extends DataTypeDefaultVisitor<DataType> {

private static final PaimonToSparkTypeVisitor INSTANCE = new PaimonToSparkTypeVisitor();
Expand Down
Loading

0 comments on commit 20a3967

Please sign in to comment.