[spark] Integrate Variant with Spark4 (#4764)
Zouxxyy authored Dec 24, 2024
1 parent c0023f0 commit 718680f
Showing 9 changed files with 206 additions and 9 deletions.
VariantTest.scala
@@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.paimon.spark.sql

class VariantTest extends VariantTestBase {}
SparkRow.java
@@ -35,6 +35,7 @@
import org.apache.paimon.utils.DateTimeUtils;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.paimon.shims.SparkShimLoader;

import java.io.Serializable;
import java.sql.Date;
@@ -145,8 +146,8 @@ public byte[] getBinary(int i) {
}

@Override
-    public Variant getVariant(int pos) {
-        throw new UnsupportedOperationException();
+    public Variant getVariant(int i) {
+        return SparkShimLoader.getSparkShim().toPaimonVariant(row.getAs(i));
}

@Override
@@ -307,8 +308,8 @@ public byte[] getBinary(int i) {
}

@Override
-    public Variant getVariant(int pos) {
-        throw new UnsupportedOperationException();
+    public Variant getVariant(int i) {
+        return SparkShimLoader.getSparkShim().toPaimonVariant(getAs(i));
}

@Override
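The two hunks above replace UnsupportedOperationException stubs: reading a variant column now routes through the version-specific shim. A minimal usage sketch of that path, built only from calls visible in this diff (the helper name readVariant is illustrative):

    import org.apache.paimon.data.variant.Variant
    import org.apache.spark.sql.paimon.shims.SparkShimLoader

    // Pull column i out of a Spark Row and let the active shim convert it into a
    // Paimon Variant (throws on Spark 3, converts a VariantVal on Spark 4).
    def readVariant(row: org.apache.spark.sql.Row, i: Int): Variant =
      SparkShimLoader.getSparkShim().toPaimonVariant(row.getAs(i))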
SparkTypeUtils.java
@@ -42,7 +42,9 @@
import org.apache.paimon.types.TinyIntType;
import org.apache.paimon.types.VarBinaryType;
import org.apache.paimon.types.VarCharType;
import org.apache.paimon.types.VariantType;

import org.apache.spark.sql.paimon.shims.SparkShimLoader;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.LongType;
@@ -217,6 +219,11 @@ public DataType visit(LocalZonedTimestampType localZonedTimestampType) {
return DataTypes.TimestampType;
}

@Override
public DataType visit(VariantType variantType) {
return SparkShimLoader.getSparkShim().SparkVariantType();
}

@Override
public DataType visit(ArrayType arrayType) {
org.apache.paimon.types.DataType elementType = arrayType.getElementType();
@@ -381,6 +388,8 @@ public org.apache.paimon.types.DataType atomic(DataType atomic) {
} else if (atomic instanceof org.apache.spark.sql.types.TimestampNTZType) {
// Move TimestampNTZType to the end for compatibility with spark3.3 and below
return new TimestampType();
} else if (SparkShimLoader.getSparkShim().isSparkVariantType(atomic)) {
return new VariantType();
}

throw new UnsupportedOperationException(
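These hunks make the type mapping symmetric: Paimon's VariantType converts to Spark's variant type and back, with both directions routed through the shim so this common module never references Spark 4 classes directly. A hedged round-trip sketch using only the shim methods added in this commit:

    import org.apache.paimon.types.VariantType
    import org.apache.spark.sql.paimon.shims.SparkShimLoader

    val shim = SparkShimLoader.getSparkShim()
    // Paimon -> Spark: yields DataTypes.VariantType on Spark 4, throws on Spark 3.
    val sparkType: org.apache.spark.sql.types.DataType = shim.SparkVariantType()
    // Spark -> Paimon: guarded by the shim's type check, mirroring atomic(...) above.
    val paimonType: org.apache.paimon.types.DataType =
      if (shim.isSparkVariantType(sparkType)) new VariantType()
      else throw new UnsupportedOperationException(s"Unsupported type: $sparkType")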
SparkShim.scala
@@ -18,6 +18,7 @@

package org.apache.spark.sql.paimon.shims

import org.apache.paimon.data.variant.Variant
import org.apache.paimon.spark.data.{SparkArrayData, SparkInternalRow}
import org.apache.paimon.types.{DataType, RowType}

@@ -33,7 +34,7 @@ import org.apache.spark.sql.types.StructType
import java.util.{Map => JMap}

/**
* A spark shim trait. It declares methods which have incompatible implementations between Spark 3
* A spark shim trait. It declares methods which have incompatible implementations between Spark 3
* and Spark 4. The specific SparkShim implementation will be loaded through Service Provider
* Interface.
*/
@@ -62,4 +63,10 @@ trait SparkShim {

def convertToExpression(spark: SparkSession, column: Column): Expression

// for variant
def toPaimonVariant(o: Object): Variant

def isSparkVariantType(dataType: org.apache.spark.sql.types.DataType): Boolean

def SparkVariantType(): org.apache.spark.sql.types.DataType
}
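Per the trait's comment, the concrete shim is discovered through the Service Provider Interface. For orientation, a minimal sketch of how such a loader is conventionally written with java.util.ServiceLoader; the real SparkShimLoader may cache or select differently, so treat this as an assumption rather than the commit's code:

    import java.util.ServiceLoader
    import scala.jdk.CollectionConverters._

    object ShimLoaderSketch {
      // Scan META-INF/services for SparkShim implementations and keep the first match.
      lazy val shim: SparkShim =
        ServiceLoader.load(classOf[SparkShim]).asScala.headOption.getOrElse(
          throw new IllegalStateException("No SparkShim implementation on the classpath"))
    }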
VariantTestBase.scala
@@ -0,0 +1,133 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.paimon.spark.sql

import org.apache.paimon.spark.PaimonSparkTestBase

import org.apache.spark.sql.Row

abstract class VariantTestBase extends PaimonSparkTestBase {

test("Paimon Variant: read and write variant") {
sql("CREATE TABLE T (id INT, v VARIANT)")
sql("""
|INSERT INTO T VALUES
| (1, parse_json('{"age":26,"city":"Beijing"}')),
| (2, parse_json('{"age":27,"city":"Hangzhou"}'))
| """.stripMargin)

checkAnswer(
sql(
"SELECT id, variant_get(v, '$.age', 'int'), variant_get(v, '$.city', 'string') FROM T ORDER BY id"),
Seq(Row(1, 26, "Beijing"), Row(2, 27, "Hangzhou"))
)
checkAnswer(
sql(
"SELECT variant_get(v, '$.city', 'string') FROM T WHERE variant_get(v, '$.age', 'int') == 26"),
Seq(Row("Beijing"))
)
checkAnswer(
sql("SELECT * FROM T WHERE variant_get(v, '$.age', 'int') == 27"),
sql("""SELECT 2, parse_json('{"age":27,"city":"Hangzhou"}')""")
)
}

test("Paimon Variant: read and write array variant") {
sql("CREATE TABLE T (id INT, v ARRAY<VARIANT>)")
sql(
"""
|INSERT INTO T VALUES
| (1, array(parse_json('{"age":26,"city":"Beijing"}'), parse_json('{"age":27,"city":"Hangzhou"}'))),
| (2, array(parse_json('{"age":27,"city":"Shanghai"}')))
| """.stripMargin)

withSparkSQLConf("spark.sql.ansi.enabled" -> "false") {
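      // ANSI mode is disabled so the out-of-range access v[1] on the
      // single-element array (id = 2) evaluates to null instead of failing.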
checkAnswer(
sql(
"SELECT id, variant_get(v[1], '$.age', 'int'), variant_get(v[0], '$.city', 'string') FROM T ORDER BY id"),
Seq(Row(1, 27, "Beijing"), Row(2, null, "Shanghai"))
)
}
}

test("Paimon Variant: complex json") {
val json =
"""
|{
| "object" : {
| "name" : "Apache Paimon",
| "age" : 2,
| "address" : {
| "street" : "Main St",
| "city" : "Hangzhou"
| }
| },
| "array" : [ 1, 2, 3, 4, 5 ],
| "string" : "Hello, World!",
| "long" : 12345678901234,
| "double" : 1.0123456789012346,
| "decimal" : 100.99,
| "boolean1" : true,
| "boolean2" : false,
| "nullField" : null
|}
|""".stripMargin

sql("CREATE TABLE T (v VARIANT)")
sql(s"""
|INSERT INTO T VALUES parse_json('$json')
| """.stripMargin)

checkAnswer(
sql("""
|SELECT
| variant_get(v, '$.object', 'string'),
| variant_get(v, '$.object.name', 'string'),
| variant_get(v, '$.object.address.street', 'string'),
| variant_get(v, '$["object"]["address"].city', 'string'),
| variant_get(v, '$.array', 'string'),
| variant_get(v, '$.array[0]', 'int'),
| variant_get(v, '$.array[3]', 'int'),
| variant_get(v, '$.string', 'string'),
| variant_get(v, '$.double', 'double'),
| variant_get(v, '$.decimal', 'decimal(5,2)'),
| variant_get(v, '$.boolean1', 'boolean'),
| variant_get(v, '$.boolean2', 'boolean'),
| variant_get(v, '$.nullField', 'string')
|FROM T
|""".stripMargin),
Seq(
Row(
"""{"address":{"city":"Hangzhou","street":"Main St"},"age":2,"name":"Apache Paimon"}""",
"Apache Paimon",
"Main St",
"Hangzhou",
"[1,2,3,4,5]",
1,
4,
"Hello, World!",
1.0123456789012346,
100.99,
true,
false,
null
))
)
}
}
Spark3Shim.scala
@@ -18,6 +18,7 @@

package org.apache.spark.sql.paimon.shims

import org.apache.paimon.data.variant.Variant
import org.apache.paimon.spark.catalyst.analysis.Spark3ResolutionRules
import org.apache.paimon.spark.catalyst.parser.extensions.PaimonSpark3SqlExtensionsParser
import org.apache.paimon.spark.data.{Spark3ArrayData, Spark3InternalRow, SparkArrayData, SparkInternalRow}
@@ -71,4 +72,11 @@ class Spark3Shim extends SparkShim {

override def convertToExpression(spark: SparkSession, column: Column): Expression = column.expr

override def toPaimonVariant(o: Object): Variant = throw new UnsupportedOperationException()

override def isSparkVariantType(dataType: org.apache.spark.sql.types.DataType): Boolean =
throw new UnsupportedOperationException()

override def SparkVariantType(): org.apache.spark.sql.types.DataType =
throw new UnsupportedOperationException()
}
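Spark 3 has no variant type, so all three variant methods deliberately throw; they should only be reachable if a VARIANT column is actually used from Spark 3. For reference, an SPI implementation like this is registered through a provider-configuration resource; a sketch under the standard ServiceLoader convention (the exact resource path in the Paimon build is an assumption):

    // src/main/resources/META-INF/services/org.apache.spark.sql.paimon.shims.SparkShim
    // containing the single line:
    //   org.apache.spark.sql.paimon.shims.Spark3Shim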
Spark4ArrayData.scala
@@ -24,6 +24,8 @@ import org.apache.spark.unsafe.types.VariantVal

class Spark4ArrayData(override val elementType: DataType) extends AbstractSparkArrayData {

-  override def getVariant(ordinal: Int): VariantVal = throw new UnsupportedOperationException
+  override def getVariant(ordinal: Int): VariantVal = {
+    val v = paimonArray.getVariant(ordinal)
+    new VariantVal(v.value(), v.metadata())
+  }
}
Spark4InternalRow.scala
@@ -24,5 +24,9 @@ import org.apache.paimon.types.RowType
import org.apache.spark.unsafe.types.VariantVal

class Spark4InternalRow(rowType: RowType) extends AbstractSparkInternalRow(rowType) {
-  override def getVariant(i: Int): VariantVal = throw new UnsupportedOperationException
+  override def getVariant(i: Int): VariantVal = {
+    val v = row.getVariant(i)
+    new VariantVal(v.value(), v.metadata())
+  }
}
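Spark4ArrayData and Spark4InternalRow above both convert in the Paimon-to-Spark direction; Spark4Shim.toPaimonVariant below is the inverse. A compact round-trip sketch built only from the constructors and accessors that appear in this diff (the object name is illustrative):

    import org.apache.paimon.data.variant.{GenericVariant, Variant}
    import org.apache.spark.unsafe.types.VariantVal

    object VariantRoundTrip {
      // Paimon -> Spark: a VariantVal is the variant's binary value plus its metadata.
      def toSpark(v: Variant): VariantVal = new VariantVal(v.value(), v.metadata())
      // Spark -> Paimon: rebuild a GenericVariant from the same two byte arrays.
      def toPaimon(v: VariantVal): Variant = new GenericVariant(v.getValue, v.getMetadata)
    }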
Spark4Shim.scala
@@ -18,6 +18,7 @@

package org.apache.spark.sql.paimon.shims

import org.apache.paimon.data.variant.{GenericVariant, Variant}
import org.apache.paimon.spark.catalyst.analysis.Spark4ResolutionRules
import org.apache.paimon.spark.catalyst.parser.extensions.PaimonSpark4SqlExtensionsParser
import org.apache.paimon.spark.data.{Spark4ArrayData, Spark4InternalRow, SparkArrayData, SparkInternalRow}
@@ -31,7 +32,8 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, Table, TableCatalog}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.internal.ExpressionUtils
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{DataTypes, StructType, VariantType}
import org.apache.spark.unsafe.types.VariantVal

import java.util.{Map => JMap}

@@ -73,4 +75,14 @@ class Spark4Shim extends SparkShim {

def convertToExpression(spark: SparkSession, column: Column): Expression =
spark.expression(column)

override def toPaimonVariant(o: Object): Variant = {
val v = o.asInstanceOf[VariantVal]
new GenericVariant(v.getValue, v.getMetadata)
}

override def isSparkVariantType(dataType: org.apache.spark.sql.types.DataType): Boolean =
dataType.isInstanceOf[VariantType]

override def SparkVariantType(): org.apache.spark.sql.types.DataType = DataTypes.VariantType
}
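A quick check of the Spark 4 implementation's contract, assuming a Spark 4 classpath (DataTypes.VariantType does not exist in Spark 3):

    import org.apache.spark.sql.types.DataTypes

    val shim = new Spark4Shim()
    assert(shim.isSparkVariantType(DataTypes.VariantType))
    assert(shim.SparkVariantType() == DataTypes.VariantType)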
