Skip to content

Commit

Permalink
Json handling improvements (#404)
Browse files Browse the repository at this point in the history
* implemented json interpolator
* implemented default arguments for derived JsonFormats
* changed default imports for besom.json.* and moved default behavior to besom.json.custom.*
* added docs for besom-json
---------

Co-authored-by: Paweł Prażak <[email protected]>
  • Loading branch information
lbialy and pawelprazak authored Apr 8, 2024
1 parent 0673b6c commit 4b10f3f
Show file tree
Hide file tree
Showing 16 changed files with 843 additions and 45 deletions.
6 changes: 5 additions & 1 deletion besom-json/src/main/scala/besom/json/JsonFormat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ object JsonReader {
implicit def func2Reader[T](f: JsValue => T): JsonReader[T] = new JsonReader[T] {
def read(json: JsValue) = f(json)
}

inline def derived[T <: Product](using JsonProtocol): JsonReader[T] = summon[JsonProtocol].jsonFormatN[T]
}

/** Provides the JSON serialization for type T.
Expand All @@ -44,14 +46,16 @@ object JsonWriter {
implicit def func2Writer[T](f: T => JsValue): JsonWriter[T] = new JsonWriter[T] {
def write(obj: T) = f(obj)
}

inline def derived[T <: Product](using JsonProtocol): JsonWriter[T] = summon[JsonProtocol].jsonFormatN[T]
}

/** Provides the JSON deserialization and serialization for type T.
*/
trait JsonFormat[T] extends JsonReader[T] with JsonWriter[T]

object JsonFormat:
inline def derived[T <: Product](using JsonProtocol) = summon[JsonProtocol].jsonFormatN[T]
inline def derived[T <: Product](using JsonProtocol): JsonFormat[T] = summon[JsonProtocol].jsonFormatN[T]

/** A special JsonReader capable of reading a legal JSON root object, i.e. either a JSON array or a JSON object.
*/
Expand Down
58 changes: 46 additions & 12 deletions besom-json/src/main/scala/besom/json/ProductFormats.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,44 @@ package besom.json
trait ProductFormats:
self: StandardFormats with AdditionalFormats =>

def writeNulls: Boolean = false
def writeNulls: Boolean = false
def requireNullsForOptions: Boolean = false

inline def jsonFormatN[T <: Product]: RootJsonFormat[T] = ${ ProductFormatsMacro.jsonFormatImpl[T]('self) }

object ProductFormatsMacro:
import scala.deriving.*
import scala.quoted.*

private def findDefaultParams[T](using quotes: Quotes, tpe: Type[T]): Expr[Map[String, Any]] =
import quotes.reflect.*

TypeRepr.of[T].classSymbol match
case None => '{ Map.empty[String, Any] }
case Some(sym) =>
val comp = sym.companionClass
try
val mod = Ref(sym.companionModule)
val names =
for p <- sym.caseFields if p.flags.is(Flags.HasDefault)
yield p.name
val namesExpr: Expr[List[String]] =
Expr.ofList(names.map(Expr(_)))

val body = comp.tree.asInstanceOf[ClassDef].body
val idents: List[Ref] =
for
case deff @ DefDef(name, _, _, _) <- body
if name.startsWith("$lessinit$greater$default")
yield mod.select(deff.symbol)
val typeArgs = TypeRepr.of[T].typeArgs
val identsExpr: Expr[List[Any]] =
if typeArgs.isEmpty then Expr.ofList(idents.map(_.asExpr))
else Expr.ofList(idents.map(_.appliedToTypes(typeArgs).asExpr))

'{ $namesExpr.zip($identsExpr).toMap }
catch case cce: ClassCastException => '{ Map.empty[String, Any] } // TODO drop after https://github.com/lampepfl/dotty/issues/19732

def jsonFormatImpl[T <: Product: Type](prodFormats: Expr[ProductFormats])(using Quotes): Expr[RootJsonFormat[T]] =
Expr.summon[Mirror.Of[T]].get match
case '{
Expand All @@ -50,25 +80,29 @@ object ProductFormatsMacro:

// instances are in correct order of fields of the product
val allInstancesExpr = Expr.ofList(prepareInstances(Type.of[elementLabels], Type.of[elementTypes]))
val defaultArguments = findDefaultParams[T]

'{
new RootJsonFormat[T]:
private val allInstances = ${ allInstancesExpr }
private val fmts = ${ prodFormats }
private val defaultArgs = ${ defaultArguments }

def read(json: JsValue): T = json match
case JsObject(fields) =>
val values = allInstances.map { case (fieldName, fieldFormat, isOption) =>
val fieldValue =
try fieldFormat.read(fields(fieldName))
catch
case e: NoSuchElementException =>
if isOption then None
else
throw DeserializationException("Object is missing required member '" ++ fieldName ++ "'", null, fieldName :: Nil)
case DeserializationException(msg, cause, fieldNames) =>
throw DeserializationException(msg, cause, fieldName :: fieldNames)

fieldValue
try fieldFormat.read(fields(fieldName))
catch
case e: NoSuchElementException =>
// if field has a default value, use it, we didn't find anything in the JSON
if defaultArgs.contains(fieldName) then defaultArgs(fieldName)
// if field is optional and requireNullsForOptions is disabled, return None
// otherwise we require an explicit null value
else if isOption && !fmts.requireNullsForOptions then None
// it's missing so we throw an exception
else throw DeserializationException("Object is missing required member '" ++ fieldName ++ "'", null, fieldName :: Nil)
case DeserializationException(msg, cause, fieldNames) =>
throw DeserializationException(msg, cause, fieldName :: fieldNames)
}
$m.fromProduct(Tuple.fromArray(values.toArray))

Expand Down
63 changes: 44 additions & 19 deletions besom-json/src/main/scala/besom/json/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,37 +14,62 @@
* limitations under the License.
*/

package besom
package besom.json

import scala.language.implicitConversions

package object json {
def deserializationError(msg: String, cause: Throwable = null, fieldNames: List[String] = Nil) =
throw new DeserializationException(msg, cause, fieldNames)

type JsField = (String, JsValue)
def serializationError(msg: String) = throw new SerializationException(msg)

case class DeserializationException(msg: String, cause: Throwable = null, fieldNames: List[String] = Nil)
extends RuntimeException(msg, cause)
class SerializationException(msg: String) extends RuntimeException(msg)

def deserializationError(msg: String, cause: Throwable = null, fieldNames: List[String] = Nil) =
throw new DeserializationException(msg, cause, fieldNames)
def serializationError(msg: String) = throw new SerializationException(msg)
private[json] class RichAny[T](any: T) {
def toJson(implicit writer: JsonWriter[T]): JsValue = writer.write(any)
}

private[json] class RichString(string: String) {
def parseJson: JsValue = JsonParser(string)
def parseJson(settings: JsonParserSettings): JsValue = JsonParser(string, settings)
}

private[json] trait DefaultExports:
type JsField = (String, JsValue)

def jsonReader[T](implicit reader: JsonReader[T]) = reader
def jsonWriter[T](implicit writer: JsonWriter[T]) = writer

implicit def enrichAny[T](any: T): RichAny[T] = new RichAny(any)
implicit def enrichString(string: String): RichString = new RichString(string)
}

package json {
private[json] trait DefaultProtocol:
implicit val defaultProtocol: JsonProtocol = DefaultJsonProtocol

case class DeserializationException(msg: String, cause: Throwable = null, fieldNames: List[String] = Nil)
extends RuntimeException(msg, cause)
class SerializationException(msg: String) extends RuntimeException(msg)
/** This allows to perform a single import: `import besom.json.*` to get basic JSON behaviour. If you need to extend JSON handling in any
* way, please use `import besom.json.custom.*`, then extend `DefaultJsonProtocol`:
*
* ```
* object MyCustomJsonProtocol extends DefaultJsonProtocol:
* given someCustomTypeFormat: JsonFormat[A] = ...
* ```
* build your customized protocol that way and set it up for your `derives` clauses using:
* ```
* given JsonProtocol = MyCustomJsonProtocol
*
* case class MyCaseClass(a: String, b: Int) derives JsonFormat
* ```
*/
object custom extends DefaultExports:
export besom.json.{JsonProtocol, DefaultJsonProtocol}
export besom.json.{JsonFormat, JsonReader, JsonWriter}
export besom.json.{RootJsonFormat, RootJsonReader, RootJsonWriter}
export besom.json.{DeserializationException, SerializationException}
export besom.json.{JsValue, JsObject, JsArray, JsString, JsNumber, JsBoolean, JsNull}

private[json] class RichAny[T](any: T) {
def toJson(implicit writer: JsonWriter[T]): JsValue = writer.write(any)
}
object DefaultJsonExports extends DefaultExports with DefaultProtocol

private[json] class RichString(string: String) {
def parseJson: JsValue = JsonParser(string)
def parseJson(settings: JsonParserSettings): JsValue = JsonParser(string, settings)
}
}
export DefaultJsonExports.*
export DefaultJsonProtocol.*
73 changes: 70 additions & 3 deletions besom-json/src/test/scala/besom/json/DerivedFormatsSpec.scala
Original file line number Diff line number Diff line change
@@ -1,18 +1,85 @@
package besom.json
package besom.json.test

import org.specs2.mutable.*

class DerivedFormatsSpec extends Specification {

"The derives keyword" should {
"behave as expected" in {
import DefaultJsonProtocol.*
given JsonProtocol = DefaultJsonProtocol
import besom.json.*

case class Color(name: String, red: Int, green: Int, blue: Int) derives JsonFormat
val color = Color("CadetBlue", 95, 158, 160)

color.toJson.convertTo[Color] mustEqual color
}

"be able to support default argument values" in {
import besom.json.*

case class Color(name: String, red: Int, green: Int, blue: Int = 160) derives JsonFormat
val color = Color("CadetBlue", 95, 158)

val json = """{"name":"CadetBlue","red":95,"green":158}"""

color.toJson.convertTo[Color] mustEqual color
json.parseJson.convertTo[Color] mustEqual color
}

"be able to support missing fields when there are default argument values" in {
import besom.json.*

case class Color(name: String, red: Int, green: Int, blue: Option[Int] = None) derives JsonFormat
val color = Color("CadetBlue", 95, 158)

val json = """{"green":158,"red":95,"name":"CadetBlue"}"""

color.toJson.compactPrint mustEqual json
color.toJson.convertTo[Color] mustEqual color
json.parseJson.convertTo[Color] mustEqual color
}

"be able to write and read nulls for optional fields" in {
import besom.json.custom.*

locally {
given jp: JsonProtocol = new DefaultJsonProtocol {
override def writeNulls = true
override def requireNullsForOptions = true
}
import jp.*

case class Color(name: String, red: Int, green: Int, blue: Option[Int]) derives JsonFormat
val color = Color("CadetBlue", 95, 158, None)

val json = """{"blue":null,"green":158,"red":95,"name":"CadetBlue"}"""

color.toJson.compactPrint mustEqual json

color.toJson.convertTo[Color] mustEqual color
json.parseJson.convertTo[Color] mustEqual color

val noExplicitNullJson = """{"green":158,"red":95,"name":"CadetBlue"}"""
noExplicitNullJson.parseJson.convertTo[Color] must throwA[DeserializationException]
}

locally {
given jp2: JsonProtocol = new DefaultJsonProtocol {
override def writeNulls = false
override def requireNullsForOptions = false
}
import jp2.*

case class Color(name: String, red: Int, green: Int, blue: Option[Int]) derives JsonFormat
val color = Color("CadetBlue", 95, 158, None)

val json = """{"green":158,"red":95,"name":"CadetBlue"}"""

color.toJson.compactPrint mustEqual json

color.toJson.convertTo[Color] mustEqual color
json.parseJson.convertTo[Color] mustEqual color
}
}
}
}
41 changes: 39 additions & 2 deletions besom-json/src/test/scala/besom/json/ProductFormatsSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ class ProductFormatsSpec extends Specification {

case class Test0()
case class Test2(a: Int, b: Option[Double])
case class Test3[A, B](as: List[A], bs: List[B])
case class Test3[A, B](as: List[A], bs: Option[List[B]] = Some(List.empty))
case class Test4(t2: Test2)
case class Test5(optA: Option[String] = Some("default"))
case class TestTransient(a: Int, b: Option[Double]) {
@transient var c = false
}
Expand All @@ -37,13 +38,35 @@ class ProductFormatsSpec extends Specification {
implicit val test2Format: JsonFormat[Test2] = jsonFormatN[Test2]
implicit def test3Format[A: JsonFormat, B: JsonFormat]: RootJsonFormat[Test3[A, B]] = jsonFormatN[Test3[A, B]]
implicit def test4Format: JsonFormat[Test4] = jsonFormatN[Test4]
implicit def test5Format: JsonFormat[Test5] = jsonFormatN[Test5]
implicit def testTransientFormat: JsonFormat[TestTransient] = jsonFormatN[TestTransient]
implicit def testStaticFormat: JsonFormat[TestStatic] = jsonFormatN[TestStatic]
implicit def testMangledFormat: JsonFormat[TestMangled] = jsonFormatN[TestMangled]
}
object TestProtocol1 extends DefaultJsonProtocol with TestProtocol
object TestProtocol2 extends DefaultJsonProtocol with TestProtocol with NullOptions

case class Foo(a: Int, b: Int)
object Foo:
import DefaultJsonProtocol.*
given JsonFormat[Foo] = jsonFormatN

"A JsonFormat derived for an inner class" should {
"compile" in {

val compileErrors = scala.compiletime.testing.typeCheckErrors(
"""
class Test:
case class Foo(a: Int, b: Int)
object Foo:
import DefaultJsonProtocol.*
given JsonFormat[Foo] = jsonFormatN"""
)

compileErrors must beEmpty
}
}

"A JsonFormat created with `jsonFormat`, for a case class with 2 elements," should {
import TestProtocol1.*
val obj = Test2(42, Some(4.2))
Expand Down Expand Up @@ -96,7 +119,7 @@ class ProductFormatsSpec extends Specification {

"A JsonFormat for a generic case class and created with `jsonFormat`" should {
import TestProtocol1.*
val obj = Test3(42 :: 43 :: Nil, "x" :: "y" :: "z" :: Nil)
val obj = Test3(42 :: 43 :: Nil, Some("x" :: "y" :: "z" :: Nil))
val json = JsObject(
"as" -> JsArray(JsNumber(42), JsNumber(43)),
"bs" -> JsArray(JsString("x"), JsString("y"), JsString("z"))
Expand Down Expand Up @@ -170,6 +193,20 @@ class ProductFormatsSpec extends Specification {
}
}

"A JsonFormat for a case class with default parameters and created with `jsonFormat`" should {
"read case classes with optional members from JSON with missing fields" in {
import TestProtocol1.*
JsObject().convertTo[Test5] mustEqual Test5(Some("default"))
}

"read a generic case class with optional members from JSON with missing fields" in {
import TestProtocol1.*
val json = JsObject("as" -> JsArray(JsNumber(23), JsNumber(5)))

json.convertTo[Test3[Int, String]] mustEqual Test3(List(23, 5), Some(List.empty))
}
}

"A JsonFormat for a case class with static fields and created with `jsonFormat`" should {
import TestProtocol1.*
val obj = TestStatic(42, Some(4.2))
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/besom/internal/Env.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ object Env:
private[internal] def getMaybe(key: String): Option[NonEmptyString] =
sys.env.get(key).flatMap(NonEmptyString(_))

import besom.json.*, DefaultJsonProtocol.*
import besom.json.*

given nesJF(using jfs: JsonFormat[String]): JsonFormat[NonEmptyString] =
new JsonFormat[NonEmptyString]:
Expand Down
3 changes: 3 additions & 0 deletions core/src/main/scala/besom/json/interpolator.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package besom.json

export besom.util.JsonInterpolator.*
Loading

0 comments on commit 4b10f3f

Please sign in to comment.