Skip to content

Commit

Permalink
[SPARK-49043][SQL] Fix interpreted codepath group by on map containin…
Browse files Browse the repository at this point in the history
…g collated strings

### What changes were proposed in this pull request?
Added ordering for PhysicalMapType in `PhysicalDataType.scala`.

### Why are the changes needed?
This feature is needed to compare maps for equality in group-by queries when they contain collated strings.
It was already functional in the codegen path.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added tests to `CollationSuite.scala`

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#47521 from ilicmarkodb/fix_group_by_on_map.

Lead-authored-by: Marko <[email protected]>
Co-authored-by: Marko Ilić <[email protected]>
Co-authored-by: Marko Ilic <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
  • Loading branch information
ilicmarkodb authored and MaxGekk committed Aug 20, 2024
1 parent b4a8029 commit 899fad4
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import scala.reflect.runtime.universe.TypeTag
import scala.reflect.runtime.universe.typeTag

import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, InterpretedOrdering, SortOrder}
import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, SQLOrderingUtil}
import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, MapData, SQLOrderingUtil}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteExactNumeric, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalExactNumeric, DecimalType, DoubleExactNumeric, DoubleType, FloatExactNumeric, FloatType, FractionalType, IntegerExactNumeric, IntegerType, IntegralType, LongExactNumeric, LongType, MapType, NullType, NumericType, ShortExactNumeric, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, VariantType, YearMonthIntervalType}
Expand Down Expand Up @@ -234,10 +234,72 @@ case object PhysicalLongType extends PhysicalLongType

case class PhysicalMapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean)
extends PhysicalDataType {
override private[sql] def ordering =
throw QueryExecutionErrors.orderedOperationUnsupportedByDataTypeError("PhysicalMapType")
override private[sql] type InternalType = Any
// maps are not orderable, we use `ordering` just to support group by queries
override private[sql] def ordering = interpretedOrdering
override private[sql] type InternalType = MapData
@transient private[sql] lazy val tag = typeTag[InternalType]

@transient
private[sql] lazy val interpretedOrdering: Ordering[MapData] = new Ordering[MapData] {
private[this] val keyOrdering =
PhysicalDataType(keyType).ordering.asInstanceOf[Ordering[Any]]
private[this] val valuesOrdering =
PhysicalDataType(valueType).ordering.asInstanceOf[Ordering[Any]]

override def compare(left: MapData, right: MapData): Int = {
val lengthLeft = left.numElements()
val lengthRight = right.numElements()
val keyArrayLeft = left.keyArray()
val valueArrayLeft = left.valueArray()
val keyArrayRight = right.keyArray()
val valueArrayRight = right.valueArray()
val minLength = math.min(lengthLeft, lengthRight)
var i = 0
while (i < minLength) {
var comp = compareElements(keyArrayLeft, keyArrayRight, keyType, i, keyOrdering)
if (comp != 0) {
return comp
}
comp = compareElements(valueArrayLeft, valueArrayRight, valueType, i, valuesOrdering)
if (comp != 0) {
return comp
}

i += 1
}

if (lengthLeft < lengthRight) {
-1
} else if (lengthLeft > lengthRight) {
1
} else {
0
}
}

private def compareElements(
arrayLeft: ArrayData,
arrayRight: ArrayData,
dataType: DataType,
position: Int,
ordering: Ordering[Any]): Int = {
val isNullLeft = arrayLeft.isNullAt(position)
val isNullRight = arrayRight.isNullAt(position)

if (isNullLeft && isNullRight) {
0
} else if (isNullLeft) {
-1
} else if (isNullRight) {
1
} else {
ordering.compare(
arrayLeft.get(position, dataType),
arrayRight.get(position, dataType)
)
}
}
}
}

class PhysicalNullType() extends PhysicalDataType with PhysicalPrimitiveType {
Expand Down
10 changes: 2 additions & 8 deletions sql/core/src/test/resources/sql-tests/results/mode.sql.out
Original file line number Diff line number Diff line change
Expand Up @@ -182,15 +182,9 @@ struct<mode(col):map<int,string>>
-- !query
SELECT mode(col, true) FROM VALUES (map(1, 'a')) AS tab(col)
-- !query schema
struct<>
struct<mode() WITHIN GROUP (ORDER BY col DESC):map<int,string>>
-- !query output
org.apache.spark.SparkIllegalArgumentException
{
"errorClass" : "_LEGACY_ERROR_TEMP_2005",
"messageParameters" : {
"dataType" : "PhysicalMapType"
}
}
{1:"a"}


-- !query
Expand Down
117 changes: 117 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1028,6 +1028,123 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
}
}

for (collation <- Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI", "")) {
for (codeGen <- Seq("NO_CODEGEN", "CODEGEN_ONLY")) {
val collationSetup = if (collation.isEmpty) "" else " COLLATE " + collation
val supportsBinaryEquality = collation.isEmpty || collation == "UNICODE" ||
CollationFactory.fetchCollation(collation).supportsBinaryEquality

test(s"Group by on map containing$collationSetup strings ($codeGen)") {
val tableName = "t"

withTable(tableName) {
withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codeGen) {
sql(s"create table $tableName" +
s" (m map<string$collationSetup, string$collationSetup>)")
sql(s"insert into $tableName values (map('aaa', 'AAA'))")
sql(s"insert into $tableName values (map('AAA', 'aaa'))")
sql(s"insert into $tableName values (map('aaa', 'AAA'))")
sql(s"insert into $tableName values (map('bbb', 'BBB'))")
sql(s"insert into $tableName values (map('aAA', 'AaA'))")
sql(s"insert into $tableName values (map('BBb', 'bBB'))")
sql(s"insert into $tableName values (map('aaaa', 'AAA'))")

val df = sql(s"select count(*) from $tableName group by m")
if (supportsBinaryEquality) {
checkAnswer(df, Seq(Row(2), Row(1), Row(1), Row(1), Row(1), Row(1)))
} else {
checkAnswer(df, Seq(Row(4), Row(2), Row(1)))
}
}
}
}

test(s"Group by on map containing structs with $collationSetup strings ($codeGen)") {
val tableName = "t"

withTable(tableName) {
withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codeGen) {
sql(s"create table $tableName" +
s" (m map<struct<fld1: string$collationSetup, fld2: string$collationSetup>, " +
s"struct<fld1: string$collationSetup, fld2: string$collationSetup>>)")
sql(s"insert into $tableName values " +
s"(map(struct('aaa', 'bbb'), struct('ccc', 'ddd')))")
sql(s"insert into $tableName values " +
s"(map(struct('Aaa', 'BBB'), struct('cCC', 'dDd')))")
sql(s"insert into $tableName values " +
s"(map(struct('AAA', 'BBb'), struct('cCc', 'DDD')))")
sql(s"insert into $tableName values " +
s"(map(struct('aaa', 'bbB'), struct('CCC', 'DDD')))")

val df = sql(s"select count(*) from $tableName group by m")
if (supportsBinaryEquality) {
checkAnswer(df, Seq(Row(1), Row(1), Row(1), Row(1)))
} else {
checkAnswer(df, Seq(Row(4)))
}
}
}
}

test(s"Group by on map containing arrays with$collationSetup strings ($codeGen)") {
val tableName = "t"

withTable(tableName) {
withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codeGen) {
sql(s"create table $tableName " +
s"(m map<array<string$collationSetup>, array<string$collationSetup>>)")
sql(s"insert into $tableName values (map(array('aaa', 'bbb'), array('ccc', 'ddd')))")
sql(s"insert into $tableName values (map(array('AAA', 'BbB'), array('Ccc', 'ddD')))")
sql(s"insert into $tableName values (map(array('AAA', 'BbB', 'Ccc'), array('ddD')))")
sql(s"insert into $tableName values (map(array('aAa', 'Bbb'), array('CCC', 'DDD')))")
sql(s"insert into $tableName values (map(array('AAa', 'BBb'), array('cCC', 'DDd')))")
sql(s"insert into $tableName values (map(array('AAA', 'BBB', 'CCC'), array('DDD')))")

val df = sql(s"select count(*) from $tableName group by m")
if (supportsBinaryEquality) {
checkAnswer(df, Seq(Row(1), Row(1), Row(1), Row(1), Row(1), Row(1)))
} else {
checkAnswer(df, Seq(Row(4), Row(2)))
}
}
}
}

test(s"Check that order by on map with$collationSetup strings fails ($codeGen)") {
val tableName = "t"
withTable(tableName) {
withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codeGen) {
sql(s"create table $tableName" +
s" (m map<string$collationSetup, string$collationSetup>, " +
s" c integer)")
sql(s"insert into $tableName values (map('aaa', 'AAA'), 1)")
sql(s"insert into $tableName values (map('BBb', 'bBB'), 2)")

// `collationSetupError` is created because "COLLATE UTF8_BINARY" is omitted in data
// type in checkError
val collationSetupError = if (collation != "UTF8_BINARY") collationSetup else ""
val query = s"select c from $tableName order by m"
val ctx = "m"
checkError(
exception = intercept[AnalysisException](sql(query)),
errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE",
parameters = Map(
"functionName" -> "`sortorder`",
"dataType" -> s"\"MAP<STRING$collationSetupError, STRING$collationSetupError>\"",
"sqlExpr" -> "\"m ASC NULLS FIRST\""
),
context = ExpectedContext(
fragment = ctx,
start = query.length - ctx.length,
stop = query.length - 1
)
)
}
}
}
}
}

test("Support operations on complex types containing collated strings") {
checkAnswer(sql("select reverse('abc' collate utf8_lcase)"), Seq(Row("cba")))
checkAnswer(sql(
Expand Down

0 comments on commit 899fad4

Please sign in to comment.