diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala
index edc8012eb3da2..50476fd15edd3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala
@@ -16,12 +16,13 @@
  */
 package org.apache.spark.sql.catalyst.expressions.json
 
-import java.io.CharArrayWriter
+import java.io.{ByteArrayOutputStream, CharArrayWriter}
 
-import com.fasterxml.jackson.core.JsonFactory
+import com.fasterxml.jackson.core._
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.ExprUtils
+import org.apache.spark.sql.catalyst.expressions.{ExprUtils, GenericInternalRow, SharedFactory}
 import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
 import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonGenerator, JacksonParser, JsonInferSchema, JSONOptions}
 import org.apache.spark.sql.catalyst.util.{ArrayData, FailFastMode, FailureSafeParser, MapData, PermissiveMode}
@@ -159,3 +160,90 @@ case class StructsToJsonEvaluator(
     converter(value)
   }
 }
+
+case class JsonTupleEvaluator(fieldsLength: Int) {
+
+  import SharedFactory._
+
+  // if processing fails this shared value will be returned
+  @transient private lazy val nullRow: Seq[InternalRow] =
+    new GenericInternalRow(Array.ofDim[Any](fieldsLength)) :: Nil
+
+  private def parseRow(parser: JsonParser, fieldNames: Seq[String]): Seq[InternalRow] = {
+    // only objects are supported
+    if (parser.nextToken() != JsonToken.START_OBJECT) return nullRow
+
+    val row = Array.ofDim[Any](fieldNames.length)
+
+    // start reading through the token stream, looking for any requested field names
+    while (parser.nextToken() != JsonToken.END_OBJECT) {
+      if (parser.getCurrentToken == JsonToken.FIELD_NAME) {
+        // check to see if this field is desired in the output
+        val jsonField = parser.currentName
+        var idx = fieldNames.indexOf(jsonField)
+        if (idx >= 0) {
+          // it is, copy the child tree to the correct location in the output row
+          val output = new ByteArrayOutputStream()
+
+          // write the output directly to UTF8 encoded byte array
+          if (parser.nextToken() != JsonToken.VALUE_NULL) {
+            Utils.tryWithResource(jsonFactory.createGenerator(output, JsonEncoding.UTF8)) {
+              generator => copyCurrentStructure(generator, parser)
+            }
+
+            val jsonValue = UTF8String.fromBytes(output.toByteArray)
+
+            // SPARK-21804: json_tuple returns null values within repeated columns
+            // except the first one; so that we need to check the remaining fields.
+            do {
+              row(idx) = jsonValue
+              idx = fieldNames.indexOf(jsonField, idx + 1)
+            } while (idx >= 0)
+          }
+        }
+      }
+
+      // always skip children, it's cheap enough to do even if copyCurrentStructure was called
+      parser.skipChildren()
+    }
+    new GenericInternalRow(row) :: Nil
+  }
+
+  private def copyCurrentStructure(generator: JsonGenerator, parser: JsonParser): Unit = {
+    parser.getCurrentToken match {
+      // if the user requests a string field it needs to be returned without enclosing
+      // quotes which is accomplished via JsonGenerator.writeRaw instead of JsonGenerator.write
+      case JsonToken.VALUE_STRING if parser.hasTextCharacters =>
+        // slight optimization to avoid allocating a String instance, though the characters
+        // still have to be decoded... Jackson doesn't have a way to access the raw bytes
+        generator.writeRaw(parser.getTextCharacters, parser.getTextOffset, parser.getTextLength)
+
+      case JsonToken.VALUE_STRING =>
+        // the normal String case, pass it through to the output without enclosing quotes
+        generator.writeRaw(parser.getText)
+
+      case JsonToken.VALUE_NULL =>
+        // a special case that needs to be handled outside of this method.
+        // if a requested field is null, the result must be null. the easiest
+        // way to achieve this is just by ignoring null tokens entirely
+        throw SparkException.internalError("Do not attempt to copy a null field.")
+
+      case _ =>
+        // handle other types including objects, arrays, booleans and numbers
+        generator.copyCurrentStructure(parser)
+    }
+  }
+
+  final def evaluate(json: UTF8String, fieldNames: Seq[String]): Seq[InternalRow] = {
+    if (json == null) return nullRow
+    try {
+      /* We know the bytes are UTF-8 encoded. Pass a Reader to avoid having Jackson
+         detect character encoding which could fail for some malformed strings */
+      Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, json)) { parser =>
+        parseRow(parser, fieldNames)
+      }
+    } catch {
+      case _: JsonProcessingException => nullRow
+    }
+  }
+}
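To make the SPARK-21804 handling in `parseRow` above concrete: when the same field name is requested more than once, every matching output column must receive the extracted value, not just the first. A minimal, dependency-free sketch of that `indexOf` loop (the object name, field names, and values are illustrative, not part of the patch):

```scala
// Mirrors the do/while over fieldNames.indexOf in JsonTupleEvaluator.parseRow.
object DuplicateFieldSketch {
  def fillDuplicates(fieldNames: Seq[String], jsonField: String, jsonValue: String): Array[Any] = {
    val row = Array.ofDim[Any](fieldNames.length)
    var idx = fieldNames.indexOf(jsonField)
    while (idx >= 0) {
      row(idx) = jsonValue
      // resume the search just past the slot we filled
      idx = fieldNames.indexOf(jsonField, idx + 1)
    }
    row
  }

  def main(args: Array[String]): Unit = {
    // json_tuple('{"a":1}', 'a', 'b', 'a') must yield ("1", null, "1"), not ("1", null, null)
    println(fillDuplicates(Seq("a", "b", "a"), "a", "1").mkString("(", ", ", ")"))
  }
}
```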
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index ac6c233f7d2ea..31b7e05a4aaac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.io._
 
+import scala.collection.immutable.ArraySeq
 import scala.util.parsing.combinator.RegexParsers
 
 import com.fasterxml.jackson.core._
@@ -28,9 +29,9 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
 import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
-import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionEvalUtils, JsonExpressionUtils, JsonToStructsEvaluator, StructsToJsonEvaluator}
+import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionEvalUtils, JsonExpressionUtils, JsonToStructsEvaluator, JsonTupleEvaluator, StructsToJsonEvaluator}
 import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
 import org.apache.spark.sql.catalyst.json._
 import org.apache.spark.sql.catalyst.trees.TreePattern.{JSON_TO_STRUCT, RUNTIME_REPLACEABLE, TreePattern}
@@ -106,7 +107,7 @@ private[this] object JsonPathParser extends RegexParsers {
   }
 }
 
-private[this] object SharedFactory {
+private[expressions] object SharedFactory {
   val jsonFactory = new JsonFactoryBuilder()
     // The two options below enabled for Hive compatibility
     .enable(JsonReadFeature.ALLOW_UNESCAPED_CONTROL_CHARS)
@@ -446,20 +447,8 @@ class GetJsonObjectEvaluator(cachedPath: UTF8String) {
 // scalastyle:on line.size.limit line.contains.tab
 case class JsonTuple(children: Seq[Expression])
   extends Generator
-  with CodegenFallback
   with QueryErrorsBase {
 
-  import SharedFactory._
-
-  override def nullable: Boolean = {
-    // a row is always returned
-    false
-  }
-
-  // if processing fails this shared value will be returned
-  @transient private lazy val nullRow: Seq[InternalRow] =
-    new GenericInternalRow(Array.ofDim[Any](fieldExpressions.length)) :: Nil
-
   // the json body is the first child
   @transient private lazy val jsonExpr: Expression = children.head
 
@@ -477,6 +466,11 @@ case class JsonTuple(children: Seq[Expression])
   // and count the number of foldable fields, we'll use this later to optimize evaluation
   @transient private lazy val constantFields: Int = foldableFieldNames.count(_ != null)
 
+  override def nullable: Boolean = {
+    // a row is always returned
+    false
+  }
+
   override def elementSchema: StructType = StructType(fieldExpressions.zipWithIndex.map {
     case (_, idx) => StructField(s"c$idx", children.head.dataType, nullable = true)
   })
@@ -499,29 +493,11 @@ case class JsonTuple(children: Seq[Expression])
     }
   }
 
+  @transient
+  private lazy val evaluator: JsonTupleEvaluator = JsonTupleEvaluator(fieldExpressions.length)
+
   override def eval(input: InternalRow): IterableOnce[InternalRow] = {
     val json = jsonExpr.eval(input).asInstanceOf[UTF8String]
-    if (json == null) {
-      return nullRow
-    }
-
-    try {
-      /* We know the bytes are UTF-8 encoded. Pass a Reader to avoid having Jackson
-         detect character encoding which could fail for some malformed strings */
-      Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, json)) { parser =>
-        parseRow(parser, input)
-      }
-    } catch {
-      case _: JsonProcessingException =>
-        nullRow
-    }
-  }
-
-  private def parseRow(parser: JsonParser, input: InternalRow): Seq[InternalRow] = {
-    // only objects are supported
-    if (parser.nextToken() != JsonToken.START_OBJECT) {
-      return nullRow
-    }
 
     // evaluate the field names as String rather than UTF8String to
     // optimize lookups from the json token, which is also a String
@@ -544,66 +520,95 @@ case class JsonTuple(children: Seq[Expression])
       }
     }
 
-    val row = Array.ofDim[Any](fieldNames.length)
-
-    // start reading through the token stream, looking for any requested field names
-    while (parser.nextToken() != JsonToken.END_OBJECT) {
-      if (parser.getCurrentToken == JsonToken.FIELD_NAME) {
-        // check to see if this field is desired in the output
-        val jsonField = parser.currentName
-        var idx = fieldNames.indexOf(jsonField)
-        if (idx >= 0) {
-          // it is, copy the child tree to the correct location in the output row
-          val output = new ByteArrayOutputStream()
+    evaluator.evaluate(json, fieldNames)
+  }
 
-          // write the output directly to UTF8 encoded byte array
-          if (parser.nextToken() != JsonToken.VALUE_NULL) {
-            Utils.tryWithResource(jsonFactory.createGenerator(output, JsonEncoding.UTF8)) {
-              generator => copyCurrentStructure(generator, parser)
-            }
+  private def genFieldNamesCode(
+      ctx: CodegenContext,
+      refFoldableFieldNames: String,
+      fieldNamesTerm: String): String = {
 
-            val jsonValue = UTF8String.fromBytes(output.toByteArray)
+    def genFoldableFieldNameCode(refIndexedSeq: String, i: Int): String = {
+      s"(String)((scala.Option)$refIndexedSeq.apply($i)).get();"
+    }
 
-            // SPARK-21804: json_tuple returns null values within repeated columns
-            // except the first one; so that we need to check the remaining fields.
-            do {
-              row(idx) = jsonValue
-              idx = fieldNames.indexOf(jsonField, idx + 1)
-            } while (idx >= 0)
+    // evaluate the field names as String rather than UTF8String to
+    // optimize lookups from the json token, which is also a String
+    val (fieldNamesEval, setFieldNames) = if (constantFields == fieldExpressions.length) {
+      // typically the user will provide the field names as foldable expressions
+      // so we can use the cached copy
+      val s = foldableFieldNames.zipWithIndex.map {
+        case (v, i) =>
+          if (v != null && v.isDefined) {
+            s"$fieldNamesTerm[$i] = ${genFoldableFieldNameCode(refFoldableFieldNames, i)};"
+          } else {
+            s"$fieldNamesTerm[$i] = null;"
           }
-        }
       }
-
-      // always skip children, it's cheap enough to do even if copyCurrentStructure was called
-      parser.skipChildren()
+      (Seq.empty[ExprCode], s)
+    } else if (constantFields == 0) {
+      // none are foldable so all field names need to be evaluated from the input row
+      val f = fieldExpressions.map(_.genCode(ctx))
+      val s = f.zipWithIndex.map {
+        case (exprCode, i) =>
+          s"""
+             |if (${exprCode.isNull}) {
+             |  $fieldNamesTerm[$i] = null;
+             |} else {
+             |  $fieldNamesTerm[$i] = ${exprCode.value}.toString();
+             |}
+             |""".stripMargin
+      }
+      (f, s)
+    } else {
+      // if there is a mix of constant and non-constant expressions
+      // prefer the cached copy when available
+      val codes = foldableFieldNames.zip(fieldExpressions).zipWithIndex.map {
+        case ((null, expr: Expression), i) =>
+          val f = expr.genCode(ctx)
+          val s =
+            s"""
+               |if (${f.isNull}) {
+               |  $fieldNamesTerm[$i] = null;
+               |} else {
+               |  $fieldNamesTerm[$i] = ${f.value}.toString();
+               |}
+               |""".stripMargin
+          (Some(f), s)
+        case ((v: Option[String], _), i) =>
+          val s = if (v.isDefined) {
+            s"$fieldNamesTerm[$i] = ${genFoldableFieldNameCode(refFoldableFieldNames, i)};"
+          } else {
+            s"$fieldNamesTerm[$i] = null;"
+          }
+          (None, s)
+      }
+      (codes.filter(c => c._1.isDefined).map(c => c._1.get), codes.map(c => c._2))
     }
-    new GenericInternalRow(row) :: Nil
+
+    s"""
+       |String[] $fieldNamesTerm = new String[${fieldExpressions.length}];
+       |${fieldNamesEval.map(_.code).mkString("\n")}
+       |${setFieldNames.mkString("\n")}
+       |""".stripMargin
   }
 
-  private def copyCurrentStructure(generator: JsonGenerator, parser: JsonParser): Unit = {
-    parser.getCurrentToken match {
-      // if the user requests a string field it needs to be returned without enclosing
-      // quotes which is accomplished via JsonGenerator.writeRaw instead of JsonGenerator.write
-      case JsonToken.VALUE_STRING if parser.hasTextCharacters =>
-        // slight optimization to avoid allocating a String instance, though the characters
-        // still have to be decoded... Jackson doesn't have a way to access the raw bytes
-        generator.writeRaw(parser.getTextCharacters, parser.getTextOffset, parser.getTextLength)
-
-      case JsonToken.VALUE_STRING =>
-        // the normal String case, pass it through to the output without enclosing quotes
-        generator.writeRaw(parser.getText)
-
-      case JsonToken.VALUE_NULL =>
-        // a special case that needs to be handled outside of this method.
-        // if a requested field is null, the result must be null. the easiest
-        // way to achieve this is just by ignoring null tokens entirely
-        throw SparkException.internalError("Do not attempt to copy a null field.")
-
-      case _ =>
-        // handle other types including objects, arrays, booleans and numbers
-        generator.copyCurrentStructure(parser)
-    }
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val refEvaluator = ctx.addReferenceObj("evaluator", evaluator)
+    val refFoldableFieldNames = ctx.addReferenceObj("foldableFieldNames", foldableFieldNames)
+    val wrapperClass = classOf[Seq[_]].getName
+    val jsonEval = jsonExpr.genCode(ctx)
+    val fieldNamesTerm = ctx.freshName("fieldNames")
+    val fieldNamesCode = genFieldNamesCode(ctx, refFoldableFieldNames, fieldNamesTerm)
+    val fieldNamesClz = classOf[ArraySeq[_]].getName
+    ev.copy(code =
+      code"""
+            |${jsonEval.code}
+            |$fieldNamesCode
+            |boolean ${ev.isNull} = false;
+            |$wrapperClass ${ev.value} = $refEvaluator.evaluate(
+            |  ${jsonEval.value}, new $fieldNamesClz.ofRef($fieldNamesTerm));
+            |""".stripMargin)
   }
 
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): JsonTuple =
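A note on the `new $fieldNamesClz.ofRef($fieldNamesTerm)` emitted by `doGenCode` above: the generated Java code builds a plain `String[]` of resolved field names and hands it to `JsonTupleEvaluator.evaluate` as a `Seq[String]` by wrapping it in `scala.collection.immutable.ArraySeq.ofRef`, which wraps the array without copying it. A minimal sketch of that wrapping in plain Scala (the array contents are illustrative):

```scala
import scala.collection.immutable.ArraySeq

// A String[] as the generated code would build it; a null entry stands for a
// non-foldable field-name expression that evaluated to null.
val fieldNames: Array[String] = Array("a", null, "b")

// Zero-copy wrap into an immutable Seq[String], as passed to evaluate().
val asSeq: Seq[String] = new ArraySeq.ofRef(fieldNames)
assert(asSeq == Seq("a", null, "b"))
```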
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
index 3a58cb92cecf2..bb32a6d7adf26 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
@@ -273,8 +273,9 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
   }
 
   test("json_tuple escaping") {
-    GenerateUnsafeProjection.generate(
-      JsonTuple(Literal("\"quote") :: Literal("\"quote") :: Nil) :: Nil)
+    checkJsonTuple(
+      JsonTuple(Literal("\"quote") :: Literal("\"quote") :: Nil),
+      InternalRow.fromSeq(Seq(null).map(UTF8String.fromString)))
   }
 
   test("json_tuple - hive key 1") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
index 84408d8e2495d..909a0db6473d4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -1456,4 +1456,23 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession {
     assert(plan.isInstanceOf[WholeStageCodegenExec])
     checkAnswer(df, Row(null))
   }
+
+  test("function json_tuple codegen - field name foldable optimize") {
+    withTempView("t") {
+      val df = Seq(("""{"a":1, "b":2}""", "a", "b")).toDF("json", "c1", "c2")
+      df.createOrReplaceTempView("t")
+
+      // all field names are non-foldable
+      val df1 = sql("SELECT json_tuple(json, c1, c2) from t")
+      checkAnswer(df1, Row("1", "2"))
+
+      // some foldable, some non-foldable
+      val df2 = sql("SELECT json_tuple(json, 'a', c2) from t")
+      checkAnswer(df2, Row("1", "2"))
+
+      // all field names are foldable
+      val df3 = sql("SELECT json_tuple(json, 'a', 'b') from t")
+      checkAnswer(df3, Row("1", "2"))
+    }
+  }
 }
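For an end-to-end view of the behaviour the new `JsonFunctionsSuite` test locks in, here is a sketch assuming an active `SparkSession` named `spark` with `spark.implicits._` imported (`c0`/`c1` are json_tuple's default output column names):

```scala
import spark.implicits._ // assumes an active SparkSession named `spark`

val df = Seq(("""{"a":1, "b":2}""", "a", "b")).toDF("json", "c1", "c2")

// Mixed foldable ('a') and non-foldable (c2) field names exercise the mixed
// branch of genFieldNamesCode and must return the same answer as before.
df.selectExpr("json_tuple(json, 'a', c2)").show()
// +---+---+
// | c0| c1|
// +---+---+
// |  1|  2|
// +---+---+
```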