From 718680fbf760991faa4dfc058b893c2ad6f9c7f3 Mon Sep 17 00:00:00 2001 From: Zouxxyy Date: Tue, 24 Dec 2024 12:17:21 +0800 Subject: [PATCH] [spark] Integrate Variant with Spark4 (#4764) --- .../apache/paimon/spark/sql/VariantTest.scala | 21 +++ .../org/apache/paimon/spark/SparkRow.java | 9 +- .../apache/paimon/spark/SparkTypeUtils.java | 9 ++ .../spark/sql/paimon/shims/SparkShim.scala | 9 +- .../paimon/spark/sql/VariantTestBase.scala | 133 ++++++++++++++++++ .../spark/sql/paimon/shims/Spark3Shim.scala | 8 ++ .../paimon/spark/data/Spark4ArrayData.scala | 6 +- .../paimon/spark/data/Spark4InternalRow.scala | 6 +- .../spark/sql/paimon/shims/Spark4Shim.scala | 14 +- 9 files changed, 206 insertions(+), 9 deletions(-) create mode 100644 paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/VariantTest.scala create mode 100644 paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/VariantTestBase.scala diff --git a/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/VariantTest.scala b/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/VariantTest.scala new file mode 100644 index 000000000000..aafd1dc4b967 --- /dev/null +++ b/paimon-spark/paimon-spark-4.0/src/test/scala/org/apache/paimon/spark/sql/VariantTest.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.paimon.spark.sql + +class VariantTest extends VariantTestBase {} diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkRow.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkRow.java index 1ad50286a2af..7d0d8ceb22a6 100644 --- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkRow.java +++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkRow.java @@ -35,6 +35,7 @@ import org.apache.paimon.utils.DateTimeUtils; import org.apache.spark.sql.Row; +import org.apache.spark.sql.paimon.shims.SparkShimLoader; import java.io.Serializable; import java.sql.Date; @@ -145,8 +146,8 @@ public byte[] getBinary(int i) { } @Override - public Variant getVariant(int pos) { - throw new UnsupportedOperationException(); + public Variant getVariant(int i) { + return SparkShimLoader.getSparkShim().toPaimonVariant(row.getAs(i)); } @Override @@ -307,8 +308,8 @@ public byte[] getBinary(int i) { } @Override - public Variant getVariant(int pos) { - throw new UnsupportedOperationException(); + public Variant getVariant(int i) { + return SparkShimLoader.getSparkShim().toPaimonVariant(getAs(i)); } @Override diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkTypeUtils.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkTypeUtils.java index f6643f758406..f72924edce42 100644 --- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkTypeUtils.java +++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkTypeUtils.java @@ -42,7 +42,9 @@ import org.apache.paimon.types.TinyIntType; import org.apache.paimon.types.VarBinaryType; import org.apache.paimon.types.VarCharType; +import org.apache.paimon.types.VariantType; +import org.apache.spark.sql.paimon.shims.SparkShimLoader; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.LongType; @@ -217,6 +219,11 @@ public DataType visit(LocalZonedTimestampType localZonedTimestampType) { return DataTypes.TimestampType; } + @Override + public DataType visit(VariantType variantType) { + return SparkShimLoader.getSparkShim().SparkVariantType(); + } + @Override public DataType visit(ArrayType arrayType) { org.apache.paimon.types.DataType elementType = arrayType.getElementType(); @@ -381,6 +388,8 @@ public org.apache.paimon.types.DataType atomic(DataType atomic) { } else if (atomic instanceof org.apache.spark.sql.types.TimestampNTZType) { // Move TimestampNTZType to the end for compatibility with spark3.3 and below return new TimestampType(); + } else if (SparkShimLoader.getSparkShim().isSparkVariantType(atomic)) { + return new VariantType(); } throw new UnsupportedOperationException( diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/paimon/shims/SparkShim.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/paimon/shims/SparkShim.scala index 334bd6e93180..3d29d7c3c577 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/paimon/shims/SparkShim.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/paimon/shims/SparkShim.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.paimon.shims +import org.apache.paimon.data.variant.Variant import org.apache.paimon.spark.data.{SparkArrayData, SparkInternalRow} import org.apache.paimon.types.{DataType, RowType} @@ 
-33,7 +34,7 @@ import org.apache.spark.sql.types.StructType import java.util.{Map => JMap} /** - * A spark shim trait. It declare methods which have incompatible implementations between Spark 3 + * A spark shim trait. It declares methods which have incompatible implementations between Spark 3 * and Spark 4. The specific SparkShim implementation will be loaded through Service Provider * Interface. */ @@ -62,4 +63,10 @@ trait SparkShim { def convertToExpression(spark: SparkSession, column: Column): Expression + // for variant + def toPaimonVariant(o: Object): Variant + + def isSparkVariantType(dataType: org.apache.spark.sql.types.DataType): Boolean + + def SparkVariantType(): org.apache.spark.sql.types.DataType } diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/VariantTestBase.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/VariantTestBase.scala new file mode 100644 index 000000000000..aeb9a9605ec6 --- /dev/null +++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/VariantTestBase.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.paimon.spark.sql
+
+import org.apache.paimon.spark.PaimonSparkTestBase
+
+import org.apache.spark.sql.Row
+
+abstract class VariantTestBase extends PaimonSparkTestBase {
+
+  test("Paimon Variant: read and write variant") {
+    sql("CREATE TABLE T (id INT, v VARIANT)")
+    sql("""
+          |INSERT INTO T VALUES
+          | (1, parse_json('{"age":26,"city":"Beijing"}')),
+          | (2, parse_json('{"age":27,"city":"Hangzhou"}'))
+          | """.stripMargin)
+
+    checkAnswer(
+      sql(
+        "SELECT id, variant_get(v, '$.age', 'int'), variant_get(v, '$.city', 'string') FROM T ORDER BY id"),
+      Seq(Row(1, 26, "Beijing"), Row(2, 27, "Hangzhou"))
+    )
+    checkAnswer(
+      sql(
+        "SELECT variant_get(v, '$.city', 'string') FROM T WHERE variant_get(v, '$.age', 'int') == 26"),
+      Seq(Row("Beijing"))
+    )
+    checkAnswer(
+      sql("SELECT * FROM T WHERE variant_get(v, '$.age', 'int') == 27"),
+      sql("""SELECT 2, parse_json('{"age":27,"city":"Hangzhou"}')""")
+    )
+  }
+
+  test("Paimon Variant: read and write array variant") {
+    sql("CREATE TABLE T (id INT, v ARRAY<VARIANT>)")
+    sql(
+      """
+        |INSERT INTO T VALUES
+        | (1, array(parse_json('{"age":26,"city":"Beijing"}'), parse_json('{"age":27,"city":"Hangzhou"}'))),
+        | (2, array(parse_json('{"age":27,"city":"Shanghai"}')))
+        | """.stripMargin)
+
+    withSparkSQLConf("spark.sql.ansi.enabled" -> "false") {
+      checkAnswer(
+        sql(
+          "SELECT id, variant_get(v[1], '$.age', 'int'), variant_get(v[0], '$.city', 'string') FROM T ORDER BY id"),
+        Seq(Row(1, 27, "Beijing"), Row(2, null, "Shanghai"))
+      )
+    }
+  }
+
+  test("Paimon Variant: complex json") {
+    val json =
+      """
+        |{
+        |  "object" : {
+        |    "name" : "Apache Paimon",
+        |    "age" : 2,
+        |    "address" : {
+        |      "street" : "Main St",
+        |      "city" : "Hangzhou"
+        |    }
+        |  },
+        |  "array" : [ 1, 2, 3, 4, 5 ],
+        |  "string" : "Hello, World!",
+        |  "long" : 12345678901234,
+        |  "double" : 1.0123456789012346,
+        |  "decimal" : 100.99,
+        |  "boolean1" : true,
+        |  "boolean2" : false,
+        |  "nullField" : null
+        |}
+        |""".stripMargin
+
+    sql("CREATE TABLE T (v VARIANT)")
+    sql(s"""
+           |INSERT INTO T VALUES parse_json('$json')
+           | """.stripMargin)
+
+    checkAnswer(
+      sql("""
+            |SELECT
+            |  variant_get(v, '$.object', 'string'),
+            |  variant_get(v, '$.object.name', 'string'),
+            |  variant_get(v, '$.object.address.street', 'string'),
+            |  variant_get(v, '$["object"]["address"].city', 'string'),
+            |  variant_get(v, '$.array', 'string'),
+            |  variant_get(v, '$.array[0]', 'int'),
+            |  variant_get(v, '$.array[3]', 'int'),
+            |  variant_get(v, '$.string', 'string'),
+            |  variant_get(v, '$.double', 'double'),
+            |  variant_get(v, '$.decimal', 'decimal(5,2)'),
+            |  variant_get(v, '$.boolean1', 'boolean'),
+            |  variant_get(v, '$.boolean2', 'boolean'),
+            |  variant_get(v, '$.nullField', 'string')
+            |FROM T
+            |""".stripMargin),
+      Seq(
+        Row(
+          """{"address":{"city":"Hangzhou","street":"Main St"},"age":2,"name":"Apache Paimon"}""",
+          "Apache Paimon",
+          "Main St",
+          "Hangzhou",
+          "[1,2,3,4,5]",
+          1,
+          4,
+          "Hello, World!",
+          1.0123456789012346,
+          100.99,
+          true,
+          false,
+          null
+        ))
+    )
+  }
+}
diff --git a/paimon-spark/paimon-spark3-common/src/main/scala/org/apache/spark/sql/paimon/shims/Spark3Shim.scala b/paimon-spark/paimon-spark3-common/src/main/scala/org/apache/spark/sql/paimon/shims/Spark3Shim.scala
index f508e2605cbc..9b96a64fb1c4 100644
--- a/paimon-spark/paimon-spark3-common/src/main/scala/org/apache/spark/sql/paimon/shims/Spark3Shim.scala
+++ b/paimon-spark/paimon-spark3-common/src/main/scala/org/apache/spark/sql/paimon/shims/Spark3Shim.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.paimon.shims
 
+import org.apache.paimon.data.variant.Variant import org.apache.paimon.spark.catalyst.analysis.Spark3ResolutionRules import org.apache.paimon.spark.catalyst.parser.extensions.PaimonSpark3SqlExtensionsParser import org.apache.paimon.spark.data.{Spark3ArrayData, Spark3InternalRow, SparkArrayData, SparkInternalRow} @@ -71,4 +72,11 @@ class Spark3Shim extends SparkShim { override def convertToExpression(spark: SparkSession, column: Column): Expression = column.expr + override def toPaimonVariant(o: Object): Variant = throw new UnsupportedOperationException() + + override def isSparkVariantType(dataType: org.apache.spark.sql.types.DataType): Boolean = + throw new UnsupportedOperationException() + + override def SparkVariantType(): org.apache.spark.sql.types.DataType = + throw new UnsupportedOperationException() } diff --git a/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/paimon/spark/data/Spark4ArrayData.scala b/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/paimon/spark/data/Spark4ArrayData.scala index be319c0a9c23..d8ba2847ab88 100644 --- a/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/paimon/spark/data/Spark4ArrayData.scala +++ b/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/paimon/spark/data/Spark4ArrayData.scala @@ -24,6 +24,8 @@ import org.apache.spark.unsafe.types.VariantVal class Spark4ArrayData(override val elementType: DataType) extends AbstractSparkArrayData { - override def getVariant(ordinal: Int): VariantVal = throw new UnsupportedOperationException - + override def getVariant(ordinal: Int): VariantVal = { + val v = paimonArray.getVariant(ordinal) + new VariantVal(v.value(), v.metadata()) + } } diff --git a/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/paimon/spark/data/Spark4InternalRow.scala b/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/paimon/spark/data/Spark4InternalRow.scala index 54b0f420ea93..9ac2766346f9 100644 --- a/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/paimon/spark/data/Spark4InternalRow.scala +++ b/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/paimon/spark/data/Spark4InternalRow.scala @@ -24,5 +24,9 @@ import org.apache.paimon.types.RowType import org.apache.spark.unsafe.types.VariantVal class Spark4InternalRow(rowType: RowType) extends AbstractSparkInternalRow(rowType) { - override def getVariant(i: Int): VariantVal = throw new UnsupportedOperationException + + override def getVariant(i: Int): VariantVal = { + val v = row.getVariant(i) + new VariantVal(v.value(), v.metadata()) + } } diff --git a/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/paimon/shims/Spark4Shim.scala b/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/paimon/shims/Spark4Shim.scala index eefddafdbfb8..33eefc7d568c 100644 --- a/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/paimon/shims/Spark4Shim.scala +++ b/paimon-spark/paimon-spark4-common/src/main/scala/org/apache/spark/sql/paimon/shims/Spark4Shim.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.paimon.shims +import org.apache.paimon.data.variant.{GenericVariant, Variant} import org.apache.paimon.spark.catalyst.analysis.Spark4ResolutionRules import org.apache.paimon.spark.catalyst.parser.extensions.PaimonSpark4SqlExtensionsParser import org.apache.paimon.spark.data.{Spark4ArrayData, Spark4InternalRow, SparkArrayData, SparkInternalRow} @@ -31,7 +32,8 @@ import org.apache.spark.sql.catalyst.rules.Rule import 
org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, Table, TableCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.internal.ExpressionUtils -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataTypes, StructType, VariantType} +import org.apache.spark.unsafe.types.VariantVal import java.util.{Map => JMap} @@ -73,4 +75,14 @@ class Spark4Shim extends SparkShim { def convertToExpression(spark: SparkSession, column: Column): Expression = spark.expression(column) + + override def toPaimonVariant(o: Object): Variant = { + val v = o.asInstanceOf[VariantVal] + new GenericVariant(v.getValue, v.getMetadata) + } + + override def isSparkVariantType(dataType: org.apache.spark.sql.types.DataType): Boolean = + dataType.isInstanceOf[VariantType] + + override def SparkVariantType(): org.apache.spark.sql.types.DataType = DataTypes.VariantType }
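
Note: every shim method above moves a Variant across the Spark/Paimon boundary as the same two byte arrays, the encoded value plus its metadata dictionary. Below is a minimal sketch of that round trip, assuming a Spark 4 runtime (org.apache.spark.unsafe.types.VariantVal) and paimon-common's GenericVariant on the classpath; the object name VariantRoundTrip is illustrative, not part of the patch.

import org.apache.paimon.data.variant.{GenericVariant, Variant}
import org.apache.spark.unsafe.types.VariantVal

object VariantRoundTrip {

  // Spark -> Paimon: wrap VariantVal's value/metadata byte arrays,
  // as Spark4Shim#toPaimonVariant does.
  def toPaimon(v: VariantVal): Variant =
    new GenericVariant(v.getValue, v.getMetadata)

  // Paimon -> Spark: rebuild a VariantVal from the same two byte arrays,
  // mirroring Spark4InternalRow#getVariant and Spark4ArrayData#getVariant.
  def toSpark(v: Variant): VariantVal =
    new VariantVal(v.value(), v.metadata())
}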
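Design-wise, routing the conversions through SparkShim (loaded via Service Provider Interface, per the trait's scaladoc) keeps paimon-spark-common free of direct references to VariantVal and VariantType, which exist only in Spark 4. Spark3Shim throws UnsupportedOperationException for all three variant methods, so Variant columns fail with a clear runtime error on Spark 3 rather than a class-loading failure.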