From e369dd40850690cb38efd385c1cb16ed3d09a192 Mon Sep 17 00:00:00 2001 From: msosnicki Date: Fri, 8 Nov 2024 17:41:34 +0100 Subject: [PATCH] Fix incorrect behavior in lenient tagged union decoders (#1620) * Fix incorrect behavior in lenient tagged union decoders * Rearrange things + add entry in CHANGELOG --- CHANGELOG.md | 1 + .../json/internals/SchemaVisitorJCodec.scala | 143 +++++++----------- .../json/SchemaVisitorJCodecTests.scala | 38 +++++ 3 files changed, 92 insertions(+), 90 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8821524df..2b7a81895 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ Thank you! * Adds utility types for working with endpoint handlers (see [#1612](https://github.com/disneystreaming/smithy4s/pull/1612)) * Add a more informative error message for repeated namespaces (see [#1608](https://github.com/disneystreaming/smithy4s/pull/1608)). * Adds `com.disneystreaming.smithy4s:smithy4s-protocol` dependency to the generation of `smithy-build.json` in the `smithy4sUpdateLSPConfig` tasks of the codegen plugins (see [#1610](https://github.com/disneystreaming/smithy4s/pull/1610)). +* Fix for the lenient union decoding [bug](https://github.com/disneystreaming/smithy4s/issues/1617) (see[#1620](https://github.com/disneystreaming/smithy4s/pull/1620)). # 0.18.25 diff --git a/modules/json/src/smithy4s/json/internals/SchemaVisitorJCodec.scala b/modules/json/src/smithy4s/json/internals/SchemaVisitorJCodec.scala index 354b72e81..8102667a2 100644 --- a/modules/json/src/smithy4s/json/internals/SchemaVisitorJCodec.scala +++ b/modules/json/src/smithy4s/json/internals/SchemaVisitorJCodec.scala @@ -976,30 +976,63 @@ private[smithy4s] class SchemaVisitorJCodec( private type Writer[A] = A => JsonWriter => Unit - private def taggedUnion[U]( - alternatives: Vector[Alt[U, _]] - )(dispatch: Alt.Dispatcher[U]): JCodec[U] = - new JCodec[U] { - val expecting: String = "tagged-union" + private abstract class TaggedUnionJCodec[U](alternatives: Vector[Alt[U, _]])( + dispatch: Alt.Dispatcher[U] + ) extends JCodec[U] { - override def canBeKey: Boolean = false + val expecting = "tagged-union" - def jsonLabel[A](alt: Alt[U, A]): String = - alt.hints.get(JsonName) match { - case None => alt.label - case Some(x) => x.value + override def canBeKey: Boolean = false + + def jsonLabel[A](alt: Alt[U, A]): String = + alt.hints.get(JsonName) match { + case None => alt.label + case Some(x) => x.value + } + + protected val handlerMap = + new util.HashMap[String, (Cursor, JsonReader) => U] { + def handler[A](alt: Alt[U, A]) = { + val codec = apply(alt.schema) + (cursor: Cursor, reader: JsonReader) => + alt.inject(cursor.decode(codec, reader)) } - private[this] val handlerMap = - new util.HashMap[String, (Cursor, JsonReader) => U] { - def handler[A](alt: Alt[U, A]) = { - val codec = apply(alt.schema) - (cursor: Cursor, reader: JsonReader) => - alt.inject(cursor.decode(codec, reader)) + alternatives.foreach(alt => put(jsonLabel(alt), handler(alt))) + } + + protected val precompiler = new smithy4s.schema.Alt.Precompiler[Writer] { + def apply[A](label: String, instance: Schema[A]): Writer[A] = { + val jsonLabel = + instance.hints.get(JsonName).map(_.value).getOrElse(label) + val jcodecA = instance.compile(self) + a => + out => { + out.writeObjectStart() + out.writeKey(jsonLabel) + jcodecA.encodeValue(a, out) + out.writeObjectEnd() } + } + } + protected val writer = dispatch.compile(precompiler) - alternatives.foreach(alt => put(jsonLabel(alt), handler(alt))) - } + def encodeValue(u: U, out: JsonWriter): Unit = { + writer(u)(out) + } + + def decodeKey(in: JsonReader): U = + in.decodeError("Cannot use coproducts as keys") + + def encodeKey(u: U, out: JsonWriter): Unit = + out.encodeError("Cannot use coproducts as keys") + + } + + private def taggedUnion[U]( + alternatives: Vector[Alt[U, _]] + )(dispatch: Alt.Dispatcher[U]): JCodec[U] = + new TaggedUnionJCodec[U](alternatives)(dispatch) { def decodeValue(cursor: Cursor, in: JsonReader): U = if (in.isNextToken('{')) { @@ -1020,59 +1053,12 @@ private[smithy4s] class SchemaVisitorJCodec( } } } else in.decodeError("Expected JSON object") - - val precompiler = new smithy4s.schema.Alt.Precompiler[Writer] { - def apply[A](label: String, instance: Schema[A]): Writer[A] = { - val jsonLabel = - instance.hints.get(JsonName).map(_.value).getOrElse(label) - val jcodecA = instance.compile(self) - a => - out => { - out.writeObjectStart() - out.writeKey(jsonLabel) - jcodecA.encodeValue(a, out) - out.writeObjectEnd() - } - } - } - val writer = dispatch.compile(precompiler) - - def encodeValue(u: U, out: JsonWriter): Unit = { - writer(u)(out) - } - - def decodeKey(in: JsonReader): U = - in.decodeError("Cannot use coproducts as keys") - - def encodeKey(u: U, out: JsonWriter): Unit = - out.encodeError("Cannot use coproducts as keys") } private def lenientTaggedUnion[U]( alternatives: Vector[Alt[U, _]] )(dispatch: Alt.Dispatcher[U]): JCodec[U] = - new JCodec[U] { - val expecting: String = "tagged-union" - - override def canBeKey: Boolean = false - - def jsonLabel[A](alt: Alt[U, A]): String = - alt.hints.get(JsonName) match { - case None => alt.label - case Some(x) => x.value - } - - private[this] val handlerMap = - new util.HashMap[String, (Cursor, JsonReader) => U] { - def handler[A](alt: Alt[U, A]) = { - val codec = apply(alt.schema) - (cursor: Cursor, reader: JsonReader) => - alt.inject(cursor.decode(codec, reader)) - } - - alternatives.foreach(alt => put(jsonLabel(alt), handler(alt))) - } - + new TaggedUnionJCodec[U](alternatives)(dispatch) { def decodeValue(cursor: Cursor, in: JsonReader): U = { var result: U = null.asInstanceOf[U] if (in.isNextToken('{')) { @@ -1080,6 +1066,7 @@ private[smithy4s] class SchemaVisitorJCodec( in.rollbackToken() while ({ val key = in.readKeyAsString() + cursor.push(key) val handler = handlerMap.get(key) if (handler eq null) in.skip() else if (in.isNextToken('n')) { @@ -1103,31 +1090,7 @@ private[smithy4s] class SchemaVisitorJCodec( } } else in.decodeError("Expected JSON object") } - val precompiler = new smithy4s.schema.Alt.Precompiler[Writer] { - def apply[A](label: String, instance: Schema[A]): Writer[A] = { - val jsonLabel = - instance.hints.get(JsonName).map(_.value).getOrElse(label) - val jcodecA = instance.compile(self) - a => - out => { - out.writeObjectStart() - out.writeKey(jsonLabel) - jcodecA.encodeValue(a, out) - out.writeObjectEnd() - } - } - } - val writer = dispatch.compile(precompiler) - - def encodeValue(u: U, out: JsonWriter): Unit = { - writer(u)(out) - } - - def decodeKey(in: JsonReader): U = - in.decodeError("Cannot use coproducts as keys") - def encodeKey(u: U, out: JsonWriter): Unit = - out.encodeError("Cannot use coproducts as keys") } private def untaggedUnion[U]( diff --git a/modules/json/test/src/smithy4s/json/SchemaVisitorJCodecTests.scala b/modules/json/test/src/smithy4s/json/SchemaVisitorJCodecTests.scala index a32aaabdf..c599c54d4 100644 --- a/modules/json/test/src/smithy4s/json/SchemaVisitorJCodecTests.scala +++ b/modules/json/test/src/smithy4s/json/SchemaVisitorJCodecTests.scala @@ -406,6 +406,44 @@ class SchemaVisitorJCodecTests() extends FunSuite { expect.same(readFromString[Either[Int, String]](json2), Left(1)) } + test("Lenient and regular unions have the same error messages") { + val json = """|{ + | "left" : {"foo": "b"} + |} + |""".stripMargin + + val schema = Schema.either( + Schema + .struct[String]( + Schema.string + .required[String]("bar", identity) + )(identity), + Schema + .struct[String]( + Schema.string + .required[String]("baz", identity) + )(identity) + ) + + val regularCodec = + JsoniterCodecCompilerImpl.defaultJsoniterCodecCompiler.fromSchema(schema) + val lenientCodec = + JsoniterCodecCompilerImpl.defaultJsoniterCodecCompiler.withLenientTaggedUnionDecoding + .fromSchema(schema) + + def decodeCheck(codec: JsonCodec[Either[String, String]]) = + expect.same( + Try( + readFromString[Either[String, String]](json)(codec) + ).toEither.left.map(_.getMessage), + Left("Missing required field (path: .left.bar)") + ) + + decodeCheck(regularCodec) + decodeCheck(lenientCodec) + + } + test("Untagged union are encoded / decoded") { val oneJ = """ {"three":"three_value"}""" val twoJ = """ {"four":4}"""