[SPARK-26370][SQL] Fix resolution of higher-order function for the same identifier.

## What changes were proposed in this pull request?

When a higher-order function uses a lambda variable with the same name as an existing column in a `Filter`, or in any other plan resolved via `Analyzer.resolveExpressionBottomUp`, e.g.:

```scala
val df = Seq(
  (Seq(1, 9, 8, 7), 1, 2),
  (Seq(5, 9, 7), 2, 2),
  (Seq.empty, 3, 2),
  (null, 4, 2)
).toDF("i", "x", "d")

checkAnswer(df.filter("exists(i, x -> x % d == 0)"),
  Seq(Row(Seq(1, 9, 8, 7), 1, 2)))
checkAnswer(df.select("x").filter("exists(i, x -> x % d == 0)"),
  Seq(Row(1)))
```

the following exception is thrown:

```
java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.BoundReference cannot be cast to org.apache.spark.sql.catalyst.expressions.NamedExpression
  at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:237)
  at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
  at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
  at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
  at scala.collection.TraversableLike.map(TraversableLike.scala:237)
  at scala.collection.TraversableLike.map$(TraversableLike.scala:230)
  at scala.collection.AbstractTraversable.map(Traversable.scala:108)
  at org.apache.spark.sql.catalyst.expressions.HigherOrderFunction.$anonfun$functionsForEval$1(higherOrderFunctions.scala:147)
  at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:237)
  at scala.collection.immutable.List.foreach(List.scala:392)
  at scala.collection.TraversableLike.map(TraversableLike.scala:237)
  at scala.collection.TraversableLike.map$(TraversableLike.scala:230)
  at scala.collection.immutable.List.map(List.scala:298)
  at org.apache.spark.sql.catalyst.expressions.HigherOrderFunction.functionsForEval(higherOrderFunctions.scala:145)
  at org.apache.spark.sql.catalyst.expressions.HigherOrderFunction.functionsForEval$(higherOrderFunctions.scala:145)
  at org.apache.spark.sql.catalyst.expressions.ArrayExists.functionsForEval$lzycompute(higherOrderFunctions.scala:369)
  at org.apache.spark.sql.catalyst.expressions.ArrayExists.functionsForEval(higherOrderFunctions.scala:369)
  at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.functionForEval(higherOrderFunctions.scala:176)
  at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.functionForEval$(higherOrderFunctions.scala:176)
  at org.apache.spark.sql.catalyst.expressions.ArrayExists.functionForEval(higherOrderFunctions.scala:369)
  at org.apache.spark.sql.catalyst.expressions.ArrayExists.nullSafeEval(higherOrderFunctions.scala:387)
  at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.eval(higherOrderFunctions.scala:190)
  at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.eval$(higherOrderFunctions.scala:185)
  at org.apache.spark.sql.catalyst.expressions.ArrayExists.eval(higherOrderFunctions.scala:369)
  at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificPredicate.eval(Unknown Source)
  at org.apache.spark.sql.execution.FilterExec.$anonfun$doExecute$3(basicPhysicalOperators.scala:216)
  at org.apache.spark.sql.execution.FilterExec.$anonfun$doExecute$3$adapted(basicPhysicalOperators.scala:215)

...
```

This happens because the `UnresolvedAttribute`s inside the `LambdaFunction` are unexpectedly resolved against the existing columns by that rule.

This PR introduces a placeholder expression, `UnresolvedNamedLambdaVariable`, to represent lambda variables until `ResolveLambdaVariables` resolves them, which prevents the unexpected resolution.
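Why a dedicated placeholder node type helps can be illustrated with a minimal, self-contained sketch (toy `Expr` types, not Spark's actual classes): a generic bottom-up resolver binds every unresolved attribute it encounters, so a lambda parameter represented as an ordinary unresolved attribute gets captured by a same-named column, while a distinct placeholder type survives untouched.

```scala
// Toy illustration (not Spark code) of why lambda parameters need their own node type.
sealed trait Expr
case class UnresolvedAttr(name: String) extends Expr   // an ordinary column reference
case class LambdaVar(name: String) extends Expr        // placeholder for lambda parameters
case class Resolved(name: String) extends Expr         // a column bound to the input schema
case class Lambda(param: String, body: Expr) extends Expr

// A generic bottom-up resolver, analogous to resolveExpressionBottomUp:
// it binds every UnresolvedAttr it sees, including ones inside lambda bodies,
// but has no case for LambdaVar and therefore leaves it alone.
def resolve(e: Expr, columns: Set[String]): Expr = e match {
  case UnresolvedAttr(n) if columns(n) => Resolved(n)
  case Lambda(p, body) => Lambda(p, resolve(body, columns))
  case other => other
}

val columns = Set("x")

// Without a placeholder, the lambda parameter `x` is wrongly bound to column `x`:
val bad = Lambda("x", UnresolvedAttr("x"))
assert(resolve(bad, columns) == Lambda("x", Resolved("x")))  // captured!

// With the placeholder, the parameter survives until lambda resolution runs:
val good = Lambda("x", LambdaVar("x"))
assert(resolve(good, columns) == Lambda("x", LambdaVar("x")))
```

This mirrors the fix below: the parser now emits `UnresolvedNamedLambdaVariable` nodes, which the generic resolution rule ignores, and `ResolveLambdaVariables` later either binds them to lambda arguments or converts them back to `UnresolvedAttribute`.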

## How was this patch tested?

Added a test and modified some tests.

Closes apache#23320 from ueshin/issues/SPARK-26370/hof_resolution.

Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
ueshin authored and cloud-fan committed Dec 14, 2018
1 parent 2d8838d commit 3dda58a
Showing 8 changed files with 72 additions and 20 deletions.
```diff
@@ -150,13 +150,14 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] {
       val lambdaMap = l.arguments.map(v => canonicalizer(v.name) -> v).toMap
       l.mapChildren(resolve(_, parentLambdaMap ++ lambdaMap))

-    case u @ UnresolvedAttribute(name +: nestedFields) =>
+    case u @ UnresolvedNamedLambdaVariable(name +: nestedFields) =>
       parentLambdaMap.get(canonicalizer(name)) match {
         case Some(lambda) =>
           nestedFields.foldLeft(lambda: Expression) { (expr, fieldName) =>
             ExtractValue(expr, Literal(fieldName), conf.resolver)
           }
-        case None => u
+        case None =>
+          UnresolvedAttribute(u.nameParts)
       }

     case _ =>
```
```diff
@@ -22,12 +22,34 @@ import java.util.concurrent.atomic.AtomicReference
 import scala.collection.mutable

 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedException}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.array.ByteArrayMethods

+/**
+ * A placeholder of lambda variables to prevent unexpected resolution of [[LambdaFunction]].
+ */
+case class UnresolvedNamedLambdaVariable(nameParts: Seq[String])
+  extends LeafExpression with NamedExpression with Unevaluable {
+
+  override def name: String =
+    nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
+
+  override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
+  override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+  override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+  override def qualifier: Seq[String] = throw new UnresolvedException(this, "qualifier")
+  override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
+  override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance")
+  override lazy val resolved = false
+
+  override def toString: String = s"lambda '$name"
+
+  override def sql: String = name
+}
+
 /**
  * A named lambda variable.
  */
@@ -79,7 +101,7 @@ case class LambdaFunction(

 object LambdaFunction {
   val identity: LambdaFunction = {
-    val id = UnresolvedAttribute.quoted("id")
+    val id = UnresolvedNamedLambdaVariable(Seq("id"))
     LambdaFunction(id, Seq(id))
   }
 }
```
```diff
@@ -1338,9 +1338,12 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
    */
   override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) {
     val arguments = ctx.IDENTIFIER().asScala.map { name =>
-      UnresolvedAttribute.quoted(name.getText)
+      UnresolvedNamedLambdaVariable(UnresolvedAttribute.quoted(name.getText).nameParts)
     }
-    LambdaFunction(expression(ctx.expression), arguments)
+    val function = expression(ctx.expression).transformUp {
+      case a: UnresolvedAttribute => UnresolvedNamedLambdaVariable(a.nameParts)
+    }
+    LambdaFunction(function, arguments)
   }

   /**
```
```diff
@@ -49,19 +49,21 @@ class ResolveLambdaVariablesSuite extends PlanTest {
     comparePlans(Analyzer.execute(plan(e1)), plan(e2))
   }

+  private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name))
+
   test("resolution - no op") {
     checkExpression(key, key)
   }

   test("resolution - simple") {
-    val in = ArrayTransform(values1, LambdaFunction('x.attr + 1, 'x.attr :: Nil))
+    val in = ArrayTransform(values1, LambdaFunction(lv('x) + 1, lv('x) :: Nil))
     val out = ArrayTransform(values1, LambdaFunction(lvInt + 1, lvInt :: Nil))
     checkExpression(in, out)
   }

   test("resolution - nested") {
     val in = ArrayTransform(values2, LambdaFunction(
-      ArrayTransform('x.attr, LambdaFunction('x.attr + 1, 'x.attr :: Nil)), 'x.attr :: Nil))
+      ArrayTransform(lv('x), LambdaFunction(lv('x) + 1, lv('x) :: Nil)), lv('x) :: Nil))
     val out = ArrayTransform(values2, LambdaFunction(
       ArrayTransform(lvArray, LambdaFunction(lvInt + 1, lvInt :: Nil)), lvArray :: Nil))
     checkExpression(in, out)
@@ -75,14 +77,14 @@ class ResolveLambdaVariablesSuite extends PlanTest {

   test("fail - name collisions") {
     val p = plan(ArrayTransform(values1,
-      LambdaFunction('x.attr + 'X.attr, 'x.attr :: 'X.attr :: Nil)))
+      LambdaFunction(lv('x) + lv('X), lv('x) :: lv('X) :: Nil)))
     val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage
     assert(msg.contains("arguments should not have names that are semantically the same"))
   }

   test("fail - lambda arguments") {
     val p = plan(ArrayTransform(values1,
-      LambdaFunction('x.attr + 'y.attr + 'z.attr, 'x.attr :: 'y.attr :: 'z.attr :: Nil)))
+      LambdaFunction(lv('x) + lv('y) + lv('z), lv('x) :: lv('y) :: lv('z) :: Nil)))
     val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage
     assert(msg.contains("does not match the number of arguments expected"))
   }
```
```diff
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or}
+import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable}
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
 import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
@@ -306,22 +306,24 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
     testProjection(originalExpr = column, expectedExpr = column)
   }

+  private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name))
+
   test("replace nulls in lambda function of ArrayFilter") {
-    testHigherOrderFunc('a, ArrayFilter, Seq('e))
+    testHigherOrderFunc('a, ArrayFilter, Seq(lv('e)))
   }

   test("replace nulls in lambda function of ArrayExists") {
-    testHigherOrderFunc('a, ArrayExists, Seq('e))
+    testHigherOrderFunc('a, ArrayExists, Seq(lv('e)))
   }

   test("replace nulls in lambda function of MapFilter") {
-    testHigherOrderFunc('m, MapFilter, Seq('k, 'v))
+    testHigherOrderFunc('m, MapFilter, Seq(lv('k), lv('v)))
   }

   test("inability to replace nulls in arbitrary higher-order function") {
     val lambdaFunc = LambdaFunction(
-      function = If('e > 0, Literal(null, BooleanType), TrueLiteral),
-      arguments = Seq[NamedExpression]('e))
+      function = If(lv('e) > 0, Literal(null, BooleanType), TrueLiteral),
+      arguments = Seq[NamedExpression](lv('e)))
     val column = ArrayTransform('a, lambdaFunc)
     testProjection(originalExpr = column, expectedExpr = column)
   }
```
```diff
@@ -246,9 +246,11 @@ class ExpressionParserSuite extends PlanTest {
     intercept("foo(a x)", "extraneous input 'x'")
   }

+  private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name))
+
   test("lambda functions") {
-    assertEqual("x -> x + 1", LambdaFunction('x + 1, Seq('x.attr)))
-    assertEqual("(x, y) -> x + y", LambdaFunction('x + 'y, Seq('x.attr, 'y.attr)))
+    assertEqual("x -> x + 1", LambdaFunction(lv('x) + 1, Seq(lv('x))))
+    assertEqual("(x, y) -> x + y", LambdaFunction(lv('x) + lv('y), Seq(lv('x), lv('y))))
   }

   test("window function expressions") {
```
```diff
@@ -85,7 +85,7 @@ FROM various_maps
 struct<>
 -- !query 5 output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(named_struct(NamePlaceholder(), `k`, NamePlaceholder(), `v1`, NamePlaceholder(), `v2`), `k`, `v1`, `v2`))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7
+cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(named_struct(NamePlaceholder(), k, NamePlaceholder(), v1, NamePlaceholder(), v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7


 -- !query 6
@@ -113,7 +113,7 @@ FROM various_maps
 struct<>
 -- !query 8 output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(named_struct(NamePlaceholder(), `k`, NamePlaceholder(), `v1`, NamePlaceholder(), `v2`), `k`, `v1`, `v2`))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7
+cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(named_struct(NamePlaceholder(), k, NamePlaceholder(), v1, NamePlaceholder(), v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7


 -- !query 9
```
```diff
@@ -2908,6 +2908,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     }
     assert(ex.getMessage.contains("Cannot use null as map key"))
   }
+
+  test("SPARK-26370: Fix resolution of higher-order function for the same identifier") {
+    val df = Seq(
+      (Seq(1, 9, 8, 7), 1, 2),
+      (Seq(5, 9, 7), 2, 2),
+      (Seq.empty, 3, 2),
+      (null, 4, 2)
+    ).toDF("i", "x", "d")
+
+    checkAnswer(df.selectExpr("x", "exists(i, x -> x % d == 0)"),
+      Seq(
+        Row(1, true),
+        Row(2, false),
+        Row(3, false),
+        Row(4, null)))
+    checkAnswer(df.filter("exists(i, x -> x % d == 0)"),
+      Seq(Row(Seq(1, 9, 8, 7), 1, 2)))
+    checkAnswer(df.select("x").filter("exists(i, x -> x % d == 0)"),
+      Seq(Row(1)))
+  }
 }

 object DataFrameFunctionsSuite {
```