
Commit

Zouxxyy committed Sep 26, 2024
1 parent 703a2a2 commit c5711df
Showing 6 changed files with 107 additions and 29 deletions.
@@ -33,6 +33,7 @@
import org.apache.paimon.options.Options;
import org.apache.paimon.reader.RecordReader;
import org.apache.paimon.reader.RecordReader.RecordIterator;
import org.apache.paimon.types.DataField;
import org.apache.paimon.types.DataType;
import org.apache.paimon.types.RowType;
import org.apache.paimon.utils.Pool;
@@ -155,13 +156,37 @@ private MessageType clipParquetSchema(GroupType parquetSchema) {
ParquetSchemaConverter.convertToParquetType(fieldName, projectedTypes[i]);
unknownFieldsIndices.add(i);
} else {
types[i] = parquetSchema.getType(fieldName);
Type type = parquetSchema.getType(fieldName);
if (type instanceof GroupType && projectedTypes[i] instanceof RowType) {
type = clipParquetGroup((GroupType) type, (RowType) projectedTypes[i]);
}
types[i] = type;
}
}

return Types.buildMessage().addFields(types).named("paimon-parquet");
}

    /** Clips {@code parquetGroup} to the fields requested by {@code requestedRowType}, recursing into nested groups. */
private GroupType clipParquetGroup(GroupType parquetGroup, RowType requestedRowType) {
Types.GroupBuilder<GroupType> builder = Types.buildGroup(parquetGroup.getRepetition());
List<DataField> fields = requestedRowType.getFields();
for (DataField field : fields) {
String fieldName = field.name();
if (parquetGroup.containsField(fieldName)) {
Type type = parquetGroup.getType(fieldName);
if (type instanceof GroupType && field.type() instanceof RowType) {
type = clipParquetGroup((GroupType) type, (RowType) field.type());
}
builder.addField(type);
} else {
                // TODO: support missing nested fields
                throw new RuntimeException("Field " + fieldName + " is missing in the Parquet schema");
}
}
return builder.named(parquetGroup.getName());
}

private void checkSchema(MessageType fileSchema, MessageType requestedSchema)
throws IOException, UnsupportedOperationException {
if (projectedFields.length != requestedSchema.getFieldCount()) {
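For reference, a standalone sketch of the same clipping recursion, written purely against the Parquet schema API (it assumes only parquet-column on the classpath; ClipSketch and its schemas are invented for illustration). Given a file schema teacher: struct<name, address: struct<street, city>> and a request for teacher.address.city, only the requested leaves survive:

import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
import org.apache.parquet.schema.Type;
import org.apache.parquet.schema.Types;

public class ClipSketch {

    // Same recursion as clipParquetGroup above, expressed over two Parquet
    // group types instead of a Paimon RowType: keep each requested field,
    // recursing when both sides are groups.
    static GroupType clip(GroupType group, GroupType requested) {
        Types.GroupBuilder<GroupType> builder = Types.buildGroup(group.getRepetition());
        for (Type wanted : requested.getFields()) {
            Type actual = group.getType(wanted.getName()); // throws if the field is absent
            if (!wanted.isPrimitive() && !actual.isPrimitive()) {
                actual = clip(actual.asGroupType(), wanted.asGroupType());
            }
            builder.addField(actual);
        }
        return builder.named(group.getName());
    }

    public static void main(String[] args) {
        MessageType fileSchema =
                Types.buildMessage()
                        .addField(
                                Types.optionalGroup()
                                        .addField(Types.optional(PrimitiveTypeName.BINARY).named("name"))
                                        .addField(
                                                Types.optionalGroup()
                                                        .addField(Types.optional(PrimitiveTypeName.BINARY).named("street"))
                                                        .addField(Types.optional(PrimitiveTypeName.BINARY).named("city"))
                                                        .named("address"))
                                        .named("teacher"))
                        .named("paimon-parquet");

        // Request only teacher.address.city.
        GroupType requested =
                Types.optionalGroup()
                        .addField(
                                Types.optionalGroup()
                                        .addField(Types.optional(PrimitiveTypeName.BINARY).named("city"))
                                        .named("address"))
                        .named("teacher");

        GroupType clipped = clip(fileSchema.getType("teacher").asGroupType(), requested);
        System.out.println(clipped); // prints teacher { address { city } }
    }
}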
@@ -370,12 +370,12 @@ private static List<ColumnDescriptor> getAllColumnDescriptorByType(
}

public static List<ParquetField> buildFieldsList(
List<DataField> childrens, List<String> fieldNames, MessageColumnIO columnIO) {
List<DataField> children, List<String> fieldNames, MessageColumnIO columnIO) {
List<ParquetField> list = new ArrayList<>();
for (int i = 0; i < childrens.size(); i++) {
for (int i = 0; i < children.size(); i++) {
list.add(
constructField(
childrens.get(i), lookupColumnByName(columnIO, fieldNames.get(i))));
children.get(i), lookupColumnByName(columnIO, fieldNames.get(i))));
}
return list;
}
@@ -81,6 +81,27 @@ public static org.apache.paimon.types.DataType toPaimonType(DataType dataType) {
return SparkToPaimonTypeVisitor.visit(dataType);
}

/**
 * Projects the given Spark {@code StructType} onto a Paimon {@code RowType}. Use this method
 * instead of {@link #toPaimonType(DataType)} when the original field ids must be retained.
 */
public static RowType projectToPaimonType(StructType requiredType, RowType rowType) {
List<DataField> fields = new ArrayList<>();
for (StructField sparkField : requiredType.fields()) {
DataField paimonField = rowType.getField(sparkField.name());
if (sparkField.dataType() instanceof StructType) {
fields.add(
paimonField.newType(
projectToPaimonType(
(StructType) sparkField.dataType(),
(RowType) paimonField.type())));
} else {
fields.add(paimonField);
}
}
return rowType.copy(fields);
}

private static class PaimonToSparkTypeVisitor extends DataTypeDefaultVisitor<DataType> {

private static final PaimonToSparkTypeVisitor INSTANCE = new PaimonToSparkTypeVisitor();
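A usage sketch for the new helper (assumptions: the class lives at org.apache.paimon.spark.SparkTypeUtils as in this commit, paimon-core and Spark are on the classpath, and the schema and field ids below are made up). Projecting address STRUCT<city: STRING> onto a row type whose fields carry ids 0-3 keeps ids 1 and 3, which a plain toPaimonType round-trip would not guarantee:

import java.util.Arrays;

import org.apache.paimon.spark.SparkTypeUtils;
import org.apache.paimon.types.DataField;
import org.apache.paimon.types.DataTypes;
import org.apache.paimon.types.RowType;

import org.apache.spark.sql.types.StructType;

public class ProjectDemo {
    public static void main(String[] args) {
        // Full table type with explicit field ids: name(0), address(1){street(2), city(3)}.
        RowType tableType =
                new RowType(
                        Arrays.asList(
                                new DataField(0, "name", DataTypes.STRING()),
                                new DataField(
                                        1,
                                        "address",
                                        DataTypes.ROW(
                                                new DataField(2, "street", DataTypes.STRING()),
                                                new DataField(3, "city", DataTypes.STRING())))));

        // Spark asks for address.city only.
        StructType required =
                new StructType().add("address", new StructType().add("city", "string"));

        RowType projected = SparkTypeUtils.projectToPaimonType(required, tableType);
        System.out.println(projected); // address keeps id 1, city keeps id 3
    }
}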
@@ -23,6 +23,7 @@ import org.apache.paimon.spark.schema.PaimonMetadataColumn
import org.apache.paimon.table.Table
import org.apache.paimon.table.source.ReadBuilder
import org.apache.paimon.types.RowType
import org.apache.paimon.utils.Preconditions.checkState

import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.types.StructType
@@ -33,28 +34,27 @@ trait ColumnPruningAndPushDown extends Scan {
def filters: Seq[Predicate]
def pushDownLimit: Option[Int] = None

val tableRowType: RowType = table.rowType
val tableSchema: StructType = SparkTypeUtils.fromPaimonRowType(tableRowType)
lazy val tableRowType: RowType = table.rowType
lazy val tableSchema: StructType = SparkTypeUtils.fromPaimonRowType(tableRowType)

final def partitionType: StructType = {
SparkTypeUtils.toSparkPartitionType(table)
}

private[paimon] val (requiredTableFields, metadataFields) = {
val nameToField = tableSchema.map(field => (field.name, field)).toMap
val _tableFields = requiredSchema.flatMap(field => nameToField.get(field.name))
val _metadataFields =
requiredSchema
.filterNot(field => tableSchema.fieldNames.contains(field.name))
.filter(field => PaimonMetadataColumn.SUPPORTED_METADATA_COLUMNS.contains(field.name))
(_tableFields, _metadataFields)
checkState(
requiredSchema.fields.forall(
field =>
tableRowType.containsField(field.name) ||
PaimonMetadataColumn.SUPPORTED_METADATA_COLUMNS.contains(field.name)))
requiredSchema.fields.partition(field => tableRowType.containsField(field.name))
}

lazy val requiredTableRowType: RowType =
SparkTypeUtils.projectToPaimonType(StructType(requiredTableFields), tableRowType)

lazy val readBuilder: ReadBuilder = {
val _readBuilder = table.newReadBuilder()
val projection =
requiredTableFields.map(field => tableSchema.fieldNames.indexOf(field.name)).toArray
_readBuilder.withProjection(projection)
val _readBuilder = table.newReadBuilder().withReadType(requiredTableRowType)
if (filters.nonEmpty) {
val pushedPredicate = PredicateBuilder.and(filters: _*)
_readBuilder.withFilter(pushedPredicate)
@@ -68,6 +68,6 @@ }
}

override def readSchema(): StructType = {
StructType(requiredTableFields ++ metadataFields)
requiredSchema
}
}
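The scan now hands the ReadBuilder a nested RowType instead of top-level column indices. A minimal sketch of the difference, under assumptions: the Table handle comes from elsewhere, and the field names and ids are invented here (they must match the actual table schema):

import java.util.Collections;

import org.apache.paimon.table.Table;
import org.apache.paimon.table.source.ReadBuilder;
import org.apache.paimon.types.DataField;
import org.apache.paimon.types.DataTypes;
import org.apache.paimon.types.RowType;

public class ReadTypeSketch {

    // Old style: withProjection can only keep or drop whole top-level columns.
    static ReadBuilder topLevelOnly(Table table) {
        return table.newReadBuilder().withProjection(new int[] {1});
    }

    // New style: withReadType carries the nested shape (address.city only)
    // down to the format readers, so struct members are pruned too.
    static ReadBuilder nested(Table table) {
        RowType readType =
                new RowType(
                        Collections.singletonList(
                                new DataField(
                                        1,
                                        "address",
                                        DataTypes.ROW(new DataField(3, "city", DataTypes.STRING())))));
        return table.newReadBuilder().withReadType(readType);
    }
}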
@@ -18,7 +18,6 @@

package org.apache.paimon.spark

import org.apache.paimon.spark.schema.PaimonMetadataColumn
import org.apache.paimon.stats.ColStats
import org.apache.paimon.types.{DataField, DataType, RowType}

@@ -65,19 +64,13 @@ case class PaimonStatistics[T <: PaimonBaseScan](scan: T) extends Statistics {

val wholeSchemaSize = getSizeForRow(scan.tableRowType)

val requiredDataSchemaSize = scan.requiredTableFields.map {
field =>
val dataField = scan.tableRowType.getField(field.name)
getSizeForField(dataField)
}.sum
val requiredDataSchemaSize =
scan.requiredTableRowType.getFields.asScala.map(field => getSizeForField(field)).sum
val requiredDataSizeInBytes =
paimonStats.mergedRecordSize().getAsLong * (requiredDataSchemaSize.toDouble / wholeSchemaSize)

val metadataSchemaSize = scan.metadataFields.map {
field =>
val dataField = PaimonMetadataColumn.get(field.name, scan.partitionType).toPaimonDataField
getSizeForField(dataField)
}.sum
val metadataSchemaSize =
scan.metadataColumns.map(field => getSizeForField(field.toPaimonDataField)).sum
val metadataSizeInBytes = paimonStats.mergedRecordCount().getAsLong * metadataSchemaSize

val sizeInBytes = (requiredDataSizeInBytes + metadataSizeInBytes).toLong
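With made-up numbers, the revised estimate works out as follows (a sketch of the arithmetic above, not a Paimon API call):

public class SizeEstimateSketch {
    public static void main(String[] args) {
        // Hypothetical figures: 96-byte full rows, of which the pruned read
        // schema (say INT + DOUBLE + one struct member) accounts for 24 bytes.
        long mergedRecordSize = 1_000_000L; // stands in for paimonStats.mergedRecordSize()
        long mergedRecordCount = 10_000L;   // stands in for paimonStats.mergedRecordCount()
        double requiredDataSchemaSize = 24.0;
        double wholeSchemaSize = 96.0;
        double requiredDataSizeInBytes =
                mergedRecordSize * (requiredDataSchemaSize / wholeSchemaSize); // 250000.0
        long metadataSchemaSize = 8; // e.g. one BIGINT metadata column
        long metadataSizeInBytes = mergedRecordCount * metadataSchemaSize; // 80000
        long sizeInBytes = (long) (requiredDataSizeInBytes + metadataSizeInBytes);
        System.out.println(sizeInBytes); // 330000 - smaller projections shrink the estimate
    }
}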
@@ -218,6 +218,45 @@ class PaimonQueryTest extends PaimonSparkTestBase {
}
}

test("Paimon Query: query nested cols") {
fileFormats.foreach {
fileFormat =>
bucketModes.foreach {
bucketMode =>
val bucketProp = if (bucketMode != -1) {
s", 'bucket-key'='name', 'bucket' = '$bucketMode' "
} else {
""
}
withTable("students") {
sql(s"""
|CREATE TABLE students (
| name STRING,
| age INT,
| course STRUCT<course_name: STRING, grade: DOUBLE>,
| teacher STRUCT<name: STRING, address: STRUCT<street: STRING, city: STRING>>
|) USING paimon
|TBLPROPERTIES ('file.format'='$fileFormat' $bucketProp);
|""".stripMargin)

sql("INSERT INTO students VALUES ('Alice', 20, STRUCT('Math', 85.0), STRUCT('John', STRUCT('Street 1', 'City 1')))")
sql("INSERT INTO students VALUES ('Bob', 22, STRUCT('Biology', 92.0), STRUCT('Jane', STRUCT('Street 2', 'City 2')))")
sql("INSERT INTO students VALUES ('Cathy', 21, STRUCT('History', 95.0), STRUCT('Jane', STRUCT('Street 3', 'City 3')))")

checkAnswer(
sql(
"SELECT course.grade, name, teacher.address, course.course_name FROM students ORDER BY name"),
Seq(
Row(85.0, "Alice", Row("Street 1", "City 1"), "Math"),
Row(92.0, "Bob", Row("Street 2", "City 2"), "Biology"),
Row(95.0, "Cathy", Row("Street 3", "City 3"), "History")
)
)
}
}
}
}

private def getAllFiles(
tableName: String,
partitions: Seq[String],
