diff --git a/paimon-format/src/main/java/org/apache/paimon/format/parquet/ParquetReaderFactory.java b/paimon-format/src/main/java/org/apache/paimon/format/parquet/ParquetReaderFactory.java
index b0715bb5389de..761bef90f3914 100644
--- a/paimon-format/src/main/java/org/apache/paimon/format/parquet/ParquetReaderFactory.java
+++ b/paimon-format/src/main/java/org/apache/paimon/format/parquet/ParquetReaderFactory.java
@@ -33,6 +33,7 @@
 import org.apache.paimon.options.Options;
 import org.apache.paimon.reader.RecordReader;
 import org.apache.paimon.reader.RecordReader.RecordIterator;
+import org.apache.paimon.types.DataField;
 import org.apache.paimon.types.DataType;
 import org.apache.paimon.types.RowType;
 import org.apache.paimon.utils.Pool;
@@ -155,13 +156,37 @@ private MessageType clipParquetSchema(GroupType parquetSchema) {
                         ParquetSchemaConverter.convertToParquetType(fieldName, projectedTypes[i]);
                 unknownFieldsIndices.add(i);
             } else {
-                types[i] = parquetSchema.getType(fieldName);
+                Type type = parquetSchema.getType(fieldName);
+                if (type instanceof GroupType && projectedTypes[i] instanceof RowType) {
+                    type = clipParquetGroup((GroupType) type, (RowType) projectedTypes[i]);
+                }
+                types[i] = type;
             }
         }

         return Types.buildMessage().addFields(types).named("paimon-parquet");
     }

+    /** Clips `parquetGroup` by `requestedRowType`. */
+    private GroupType clipParquetGroup(GroupType parquetGroup, RowType requestedRowType) {
+        Types.GroupBuilder<GroupType> builder = Types.buildGroup(parquetGroup.getRepetition());
+        List<DataField> fields = requestedRowType.getFields();
+        for (DataField field : fields) {
+            String fieldName = field.name();
+            if (parquetGroup.containsField(fieldName)) {
+                Type type = parquetGroup.getType(fieldName);
+                if (type instanceof GroupType && field.type() instanceof RowType) {
+                    type = clipParquetGroup((GroupType) type, (RowType) field.type());
+                }
+                builder.addField(type);
+            } else {
+                // TODO: support missing nested fields
+                throw new RuntimeException("field " + fieldName + " is missing");
+            }
+        }
+        return builder.named(parquetGroup.getName());
+    }
+
     private void checkSchema(MessageType fileSchema, MessageType requestedSchema)
             throws IOException, UnsupportedOperationException {
         if (projectedFields.length != requestedSchema.getFieldCount()) {
diff --git a/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/ParquetSplitReaderUtil.java b/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/ParquetSplitReaderUtil.java
index 90abaa992c175..860ec54fa88b0 100644
--- a/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/ParquetSplitReaderUtil.java
+++ b/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/ParquetSplitReaderUtil.java
@@ -370,12 +370,12 @@ private static List<ColumnDescriptor> getAllColumnDescriptorByType(
     }

     public static List<ParquetField> buildFieldsList(
-            List<DataField> childrens, List<String> fieldNames, MessageColumnIO columnIO) {
+            List<DataField> children, List<String> fieldNames, MessageColumnIO columnIO) {
         List<ParquetField> list = new ArrayList<>();
-        for (int i = 0; i < childrens.size(); i++) {
+        for (int i = 0; i < children.size(); i++) {
             list.add(
                     constructField(
-                            childrens.get(i), lookupColumnByName(columnIO, fieldNames.get(i))));
+                            children.get(i), lookupColumnByName(columnIO, fieldNames.get(i))));
         }
         return list;
     }
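The `clipParquetGroup` method above rebuilds each requested group from the file schema, recursing into nested rows. A minimal, self-contained sketch (not part of the patch) of the same idea against the parquet-mr `Types` builder API — the schema and field names are made up for illustration:

```java
import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
import org.apache.parquet.schema.Types;

public class ClipGroupSketch {
    public static void main(String[] args) {
        // File schema with a nested group: students { name, address { street, city } }.
        MessageType fileSchema =
                Types.buildMessage()
                        .required(PrimitiveTypeName.BINARY).named("name")
                        .requiredGroup()
                            .required(PrimitiveTypeName.BINARY).named("street")
                            .required(PrimitiveTypeName.BINARY).named("city")
                        .named("address")
                        .named("students");

        // Keep only address.city, as a pruned read schema would request.
        GroupType address = fileSchema.getType("address").asGroupType();
        GroupType clipped =
                Types.buildGroup(address.getRepetition())
                        .addField(address.getType("city"))
                        .named(address.getName());
        System.out.println(clipped); // required group address { required binary city; }
    }
}
```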
diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkTypeUtils.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkTypeUtils.java
index bc5eed00e98b9..4f0b3310a1e30 100644
--- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkTypeUtils.java
+++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkTypeUtils.java
@@ -81,6 +81,27 @@ public static org.apache.paimon.types.DataType toPaimonType(DataType dataType) {
         return SparkToPaimonTypeVisitor.visit(dataType);
     }

+    /**
+     * Projects the specified schema from a Spark `StructType` onto a Paimon `RowType`. Use this
+     * method instead of {@link #toPaimonType(DataType)} when the field ids need to be retained.
+     */
+    public static RowType projectToPaimonType(StructType requiredType, RowType rowType) {
+        List<DataField> fields = new ArrayList<>();
+        for (StructField sparkField : requiredType.fields()) {
+            DataField paimonField = rowType.getField(sparkField.name());
+            if (sparkField.dataType() instanceof StructType) {
+                fields.add(
+                        paimonField.newType(
+                                projectToPaimonType(
+                                        (StructType) sparkField.dataType(),
+                                        (RowType) paimonField.type())));
+            } else {
+                fields.add(paimonField);
+            }
+        }
+        return rowType.copy(fields);
+    }
+
     private static class PaimonToSparkTypeVisitor extends DataTypeDefaultVisitor<DataType> {

         private static final PaimonToSparkTypeVisitor INSTANCE = new PaimonToSparkTypeVisitor();
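A hedged usage sketch of the new `projectToPaimonType` (not part of the patch; the field ids 0-3 and the schema are illustrative only):

```java
import org.apache.paimon.spark.SparkTypeUtils;
import org.apache.paimon.types.DataTypes;
import org.apache.paimon.types.RowType;
import org.apache.spark.sql.types.StructType;

public class ProjectToPaimonSketch {
    public static void main(String[] args) {
        // Full table type: ROW<name STRING, course ROW<course_name STRING, grade DOUBLE>>.
        RowType tableType =
                DataTypes.ROW(
                        DataTypes.FIELD(0, "name", DataTypes.STRING()),
                        DataTypes.FIELD(
                                1,
                                "course",
                                DataTypes.ROW(
                                        DataTypes.FIELD(2, "course_name", DataTypes.STRING()),
                                        DataTypes.FIELD(3, "grade", DataTypes.DOUBLE()))));

        // Spark only requires course.grade.
        StructType required =
                new StructType()
                        .add(
                                "course",
                                new StructType()
                                        .add("grade", org.apache.spark.sql.types.DataTypes.DoubleType));

        // Expected: ROW<course ROW<grade DOUBLE>> with ids 1 and 3 preserved.
        System.out.println(SparkTypeUtils.projectToPaimonType(required, tableType));
    }
}
```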
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/ColumnPruningAndPushDown.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/ColumnPruningAndPushDown.scala
index c36c2fff2ca91..346e979b9daf4 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/ColumnPruningAndPushDown.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/ColumnPruningAndPushDown.scala
@@ -23,6 +23,7 @@ import org.apache.paimon.spark.schema.PaimonMetadataColumn
 import org.apache.paimon.table.Table
 import org.apache.paimon.table.source.ReadBuilder
 import org.apache.paimon.types.RowType
+import org.apache.paimon.utils.Preconditions.checkState

 import org.apache.spark.sql.connector.read.Scan
 import org.apache.spark.sql.types.StructType
@@ -33,28 +34,27 @@ trait ColumnPruningAndPushDown extends Scan {
   def filters: Seq[Predicate]
   def pushDownLimit: Option[Int] = None

-  val tableRowType: RowType = table.rowType
-  val tableSchema: StructType = SparkTypeUtils.fromPaimonRowType(tableRowType)
+  lazy val tableRowType: RowType = table.rowType
+  lazy val tableSchema: StructType = SparkTypeUtils.fromPaimonRowType(tableRowType)

   final def partitionType: StructType = {
     SparkTypeUtils.toSparkPartitionType(table)
   }

   private[paimon] val (requiredTableFields, metadataFields) = {
-    val nameToField = tableSchema.map(field => (field.name, field)).toMap
-    val _tableFields = requiredSchema.flatMap(field => nameToField.get(field.name))
-    val _metadataFields =
-      requiredSchema
-        .filterNot(field => tableSchema.fieldNames.contains(field.name))
-        .filter(field => PaimonMetadataColumn.SUPPORTED_METADATA_COLUMNS.contains(field.name))
-    (_tableFields, _metadataFields)
+    checkState(
+      requiredSchema.fields.forall(
+        field =>
+          tableRowType.containsField(field.name) ||
+            PaimonMetadataColumn.SUPPORTED_METADATA_COLUMNS.contains(field.name)))
+    requiredSchema.fields.partition(field => tableRowType.containsField(field.name))
   }

+  lazy val requiredTableRowType: RowType =
+    SparkTypeUtils.projectToPaimonType(StructType(requiredTableFields), tableRowType)
+
   lazy val readBuilder: ReadBuilder = {
-    val _readBuilder = table.newReadBuilder()
-    val projection =
-      requiredTableFields.map(field => tableSchema.fieldNames.indexOf(field.name)).toArray
-    _readBuilder.withProjection(projection)
+    val _readBuilder = table.newReadBuilder().withReadType(requiredTableRowType)
     if (filters.nonEmpty) {
       val pushedPredicate = PredicateBuilder.and(filters: _*)
       _readBuilder.withFilter(pushedPredicate)
@@ -68,6 +68,6 @@ trait ColumnPruningAndPushDown extends Scan {
   }

   override def readSchema(): StructType = {
-    StructType(requiredTableFields ++ metadataFields)
+    requiredSchema
   }
 }
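The key change above swaps the index-based `withProjection` for `withReadType`, so the read schema can prune nested leaves rather than only whole top-level columns. A short sketch under assumed handles (`table` and the pruned row type would come from the surrounding scan):

```java
import org.apache.paimon.table.Table;
import org.apache.paimon.table.source.ReadBuilder;
import org.apache.paimon.types.RowType;

public class ReadTypeSketch {
    static ReadBuilder pruneNested(Table table, RowType requiredTableRowType) {
        // Before: table.newReadBuilder().withProjection(new int[] {0, 2})
        // could only drop whole top-level columns.
        // After: a nested RowType (e.g. course ROW<grade DOUBLE>) prunes leaves too.
        return table.newReadBuilder().withReadType(requiredTableRowType);
    }
}
```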
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonStatistics.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonStatistics.scala
index 963d9fadd2979..1e5483dcf2c88 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonStatistics.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonStatistics.scala
@@ -18,7 +18,6 @@

 package org.apache.paimon.spark

-import org.apache.paimon.spark.schema.PaimonMetadataColumn
 import org.apache.paimon.stats.ColStats
 import org.apache.paimon.types.{DataField, DataType, RowType}

@@ -65,19 +64,13 @@ case class PaimonStatistics[T <: PaimonBaseScan](scan: T) extends Statistics {

     val wholeSchemaSize = getSizeForRow(scan.tableRowType)

-    val requiredDataSchemaSize = scan.requiredTableFields.map {
-      field =>
-        val dataField = scan.tableRowType.getField(field.name)
-        getSizeForField(dataField)
-    }.sum
+    val requiredDataSchemaSize =
+      scan.requiredTableRowType.getFields.asScala.map(field => getSizeForField(field)).sum
     val requiredDataSizeInBytes =
       paimonStats.mergedRecordSize().getAsLong * (requiredDataSchemaSize.toDouble / wholeSchemaSize)

-    val metadataSchemaSize = scan.metadataFields.map {
-      field =>
-        val dataField = PaimonMetadataColumn.get(field.name, scan.partitionType).toPaimonDataField
-        getSizeForField(dataField)
-    }.sum
+    val metadataSchemaSize =
+      scan.metadataColumns.map(field => getSizeForField(field.toPaimonDataField)).sum
     val metadataSizeInBytes = paimonStats.mergedRecordCount().getAsLong * metadataSchemaSize

     val sizeInBytes = (requiredDataSizeInBytes + metadataSizeInBytes).toLong
diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonQueryTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonQueryTest.scala
index c1814096fb7d6..1ba69ba9ef2ec 100644
--- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonQueryTest.scala
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonQueryTest.scala
@@ -218,6 +218,45 @@ class PaimonQueryTest extends PaimonSparkTestBase {
     }
   }

+  test("Paimon Query: query nested cols") {
+    fileFormats.foreach {
+      fileFormat =>
+        bucketModes.foreach {
+          bucketMode =>
+            val bucketProp = if (bucketMode != -1) {
+              s", 'bucket-key'='name', 'bucket' = '$bucketMode' "
+            } else {
+              ""
+            }
+            withTable("students") {
+              sql(s"""
+                     |CREATE TABLE students (
+                     |  name STRING,
+                     |  age INT,
+                     |  course STRUCT<course_name: STRING, grade: DOUBLE>,
+                     |  teacher STRUCT<name: STRING, address: STRUCT<street: STRING, city: STRING>>
+                     |) USING paimon
+                     |TBLPROPERTIES ('file.format'='$fileFormat' $bucketProp);
+                     |""".stripMargin)

+              sql("INSERT INTO students VALUES ('Alice', 20, STRUCT('Math', 85.0), STRUCT('John', STRUCT('Street 1', 'City 1')))")
+              sql("INSERT INTO students VALUES ('Bob', 22, STRUCT('Biology', 92.0), STRUCT('Jane', STRUCT('Street 2', 'City 2')))")
+              sql("INSERT INTO students VALUES ('Cathy', 21, STRUCT('History', 95.0), STRUCT('Jane', STRUCT('Street 3', 'City 3')))")

+              checkAnswer(
+                sql(
+                  "SELECT course.grade, name, teacher.address, course.course_name FROM students ORDER BY name"),
+                Seq(
+                  Row(85.0, "Alice", Row("Street 1", "City 1"), "Math"),
+                  Row(92.0, "Bob", Row("Street 2", "City 2"), "Biology"),
+                  Row(95.0, "Cathy", Row("Street 3", "City 3"), "History")
+                )
+              )
+            }
+        }
+    }
+  }
+
   private def getAllFiles(
       tableName: String,
       partitions: Seq[String],
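For the `PaimonStatistics` hunk above, a toy calculation (all numbers assumed) of the proportional size estimate it implements: pruned data size scales with the ratio of the pruned schema width to the full row width, plus a per-row cost for metadata columns.

```java
public class SizeEstimateSketch {
    public static void main(String[] args) {
        long mergedRecordSize = 1_000_000L; // total data bytes reported by table stats
        long mergedRecordCount = 10_000L; // row count reported by table stats
        double requiredDataSchemaSize = 16.0; // summed estimated widths of pruned fields
        double wholeSchemaSize = 64.0; // estimated width of the full row
        double metadataSchemaSize = 8.0; // e.g. one metadata column

        long sizeInBytes =
                (long)
                        (mergedRecordSize * (requiredDataSchemaSize / wholeSchemaSize)
                                + mergedRecordCount * metadataSchemaSize);
        System.out.println(sizeInBytes); // 330000
    }
}
```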