Skip to content

Commit

Permalink
Merge pull request #326 from VirtusLab/issue-215
Browse files Browse the repository at this point in the history
  • Loading branch information
lbialy authored Jul 24, 2024
2 parents a58465a + dd31282 commit 5bf61d1
Show file tree
Hide file tree
Showing 2 changed files with 223 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package org.virtuslab.yaml
import org.virtuslab.yaml.Node.*

import scala.compiletime.*
import scala.quoted.*
import scala.deriving.Mirror

private[yaml] trait YamlDecoderCompanionCrossCompat extends DecoderMacros {
Expand All @@ -12,70 +13,11 @@ private[yaml] trait YamlDecoderCompanionCrossCompat extends DecoderMacros {
}

private[yaml] trait DecoderMacros {
protected def extractKeyValues(
mappings: Map[Node, Node]
): Either[ConstructError, Map[String, Node]] = {
val keyValueMap = mappings
.map { (k, v) =>
k match {
case ScalarNode(scalarKey, _) => Right((scalarKey, v))
case node =>
Left(ConstructError.from(s"Parameter of a class must be a scalar value", node))
}
}
val (error, valuesSeq) = keyValueMap.partitionMap(identity)

if (error.nonEmpty) Left(error.head)
else Right(valuesSeq.toMap)
}

protected def constructValues[T](
elemLabels: List[String],
instances: List[YamlDecoder[_]],
optionalTypes: List[Boolean],
valuesMap: Map[String, Node],
p: Mirror.ProductOf[T],
parentNode: Node
) = {
val values = elemLabels.zip(instances).zip(optionalTypes).map { case ((label, c), isOptional) =>
valuesMap.get(label) match
case Some(value) => c.construct(value)
case None =>
if (isOptional) Right(None)
else Left(ConstructError.from(s"Key $label doesn't exist in parsed document", parentNode))
}
val (left, right) = values.partitionMap(identity)
if left.nonEmpty then Left(left.head)
else Right(p.fromProduct(Tuple.fromArray(right.toArray)))
protected inline def deriveProduct[T](p: Mirror.ProductOf[T]) = ${
DecoderMacros.deriveProductImpl[T]('p)
}

protected inline def deriveProduct[T](p: Mirror.ProductOf[T]) =
val instances = summonAll[p.MirroredElemTypes]
val elemLabels = getElemLabels[p.MirroredElemLabels]
val optionalTypes = getOptionalTypes[p.MirroredElemTypes]
new YamlDecoder[T] {
override def construct(node: Node)(using
constructor: LoadSettings = LoadSettings.empty
): Either[ConstructError, T] =
node match
case Node.MappingNode(mappings, _) =>
for {
valuesMap <- extractKeyValues(mappings)
constructedValues <- constructValues(
elemLabels,
instances,
optionalTypes,
valuesMap,
p,
node
)
} yield (constructedValues)
case _ =>
Left(
ConstructError.from(s"Expected MappingNode, got ${node.getClass.getSimpleName}", node)
)
}

protected inline def sumOf[T](s: Mirror.SumOf[T]) =
val instances = summonSumOf[s.MirroredElemTypes].asInstanceOf[List[YamlDecoder[T]]]
new YamlDecoder[T]:
Expand All @@ -94,17 +36,164 @@ private[yaml] trait DecoderMacros {
}
case _: EmptyTuple => Nil

protected inline def summonAll[T <: Tuple]: List[YamlDecoder[_]] = inline erasedValue[T] match
case _: EmptyTuple => Nil
case _: (t *: ts) => summonInline[YamlDecoder[t]] :: summonAll[ts]
}

object DecoderMacros {

protected def constructValues[T](
instances: List[(String, YamlDecoder[?], Boolean)],
valuesMap: Map[String, Node],
defaultParams: Map[String, () => Any],
p: Mirror.ProductOf[T],
parentNode: Node
): Either[ConstructError, T] = {
val values = instances.map { case (label, c, isOptional) =>
valuesMap.get(label) match
case Some(value) => c.construct(value)
case None =>
if (isOptional) Right(None)
else if (defaultParams.contains(label))
val defaultParamCreator = defaultParams(label)
val defaultParamValue = defaultParamCreator()
Right(defaultParamValue)
else Left(ConstructError.from(s"Key $label doesn't exist in parsed document", parentNode))
}
val (left, right) = values.partitionMap(identity)
if left.nonEmpty then Left(left.head)
else Right(p.fromProduct(Tuple.fromArray(right.toArray)))
}

private def extractKeyValues(
mappings: Map[Node, Node]
): Either[ConstructError, Map[String, Node]] = {
val keyValueMap = mappings
.map { (k, v) =>
k match {
case ScalarNode(scalarKey, _) => Right((scalarKey, v))
case node =>
Left(ConstructError.from(s"Parameter of a class must be a scalar value", node))
}
}
val (error, valuesSeq) = keyValueMap.partitionMap(identity)

if (error.nonEmpty) Left(error.head)
else Right(valuesSeq.toMap)
}

def deriveProductImpl[T: Type](p: Expr[Mirror.ProductOf[T]])(using
Quotes
): Expr[YamlDecoder[T]] =

// returns a list of tuples of label, instance, isOptional
def prepareInstances(
elemLabels: Type[?],
elemTypes: Type[?]
): List[Expr[(String, YamlDecoder[?], Boolean)]] =
(elemLabels, elemTypes) match
case ('[EmptyTuple], '[EmptyTuple]) => Nil
case ('[label *: labelsTail], '[tpe *: tpesTail]) =>
val label = Type.valueOfConstant[label].get.asInstanceOf[String]
val isOption = Type.of[tpe] match
case '[Option[?]] => Expr(true)
case _ => Expr(false)

val fieldName = Expr(label)
val fieldFormat = Expr.summon[YamlDecoder[tpe]].getOrElse {
quotes.reflect.report
.errorAndAbort("Missing given instance of YamlDecoder[" ++ Type.show[tpe] ++ "]")
}
val namedInstance = '{ (${ fieldName }, $fieldFormat, ${ isOption }) }
namedInstance :: prepareInstances(Type.of[labelsTail], Type.of[tpesTail])

p match
case '{
$m: Mirror.ProductOf[T] {
type MirroredElemLabels = elementLabels; type MirroredElemTypes = elementTypes
}
} =>
val allInstancesExpr =
Expr.ofList(prepareInstances(Type.of[elementLabels], Type.of[elementTypes]))
val defaultParamsExpr = findDefaultParams[T]

protected inline def getElemLabels[T <: Tuple]: List[String] = inline erasedValue[T] match
case _: EmptyTuple => Nil
case _: (head *: tail) => constValue[head].toString :: getElemLabels[tail]
'{
new YamlDecoder[T] {
private val allInstances = $allInstancesExpr
private val defaultParams = $defaultParamsExpr
private val mirror = $p

protected inline def getOptionalTypes[T <: Tuple]: List[Boolean] = inline erasedValue[T] match
case _: EmptyTuple => Nil
case _: (Option[_] *: tail) => true :: getOptionalTypes[tail]
case _: (_ *: tail) => false :: getOptionalTypes[tail]
override def construct(node: Node)(using
constructor: LoadSettings = LoadSettings.empty
): Either[ConstructError, T] =
node match
case Node.MappingNode(mappings, _) =>
for {
valuesMap <- extractKeyValues(mappings)
constructedValues <- constructValues(
allInstances,
valuesMap,
defaultParams,
mirror,
node
)
} yield (constructedValues)
case _ =>
Left(
ConstructError.from(
s"Expected MappingNode, got ${node.getClass.getSimpleName}",
node
)
)
}
}

private val DefaultParamPrefix = "$lessinit$greater$default$"

protected def findDefaultParams[T](using
quotes: Quotes,
tpe: Type[T]
): Expr[Map[String, () => Any]] =
import quotes.reflect.*

TypeRepr.of[T].classSymbol match
case None => '{ Map.empty[String, () => Any] }
case Some(sym: Symbol) =>
try
val comp = sym.companionClass
val mod = Ref(sym.companionModule)
val paramWithDefaultMeta =
for (p, idx) <- sym.caseFields.zipWithIndex if p.flags.is(Flags.HasDefault)
// +1 because the names are generated starting from 1
yield (p.name, idx + 1)

val idents: List[(String, Ref)] =
for (paramName, idx) <- paramWithDefaultMeta
yield paramName -> mod.select(
// head is safe here because we know there has to be a getter for the default value
// because we checked for HasDefault flag
comp.methodMember(DefaultParamPrefix + idx.toString).head
)

val typeArgs = TypeRepr.of[T].typeArgs

// we create an expression of a list of tuples of name and thunks that return the default value for a given parameter
val defaultsThunksExpr: Expr[List[(String, () => Any)]] =
if typeArgs.isEmpty then
Expr.ofList(
idents.map { case (name, ref) => name -> ref.asExpr }.map { case (name, '{ $x }) =>
'{ (${ Expr(name) }, () => $x) }
}
)
else // if there are type parameters, we need to apply the type parameters to accessors
Expr.ofList(
idents.map { case (name, ref) => name -> ref.appliedToTypes(typeArgs).asExpr }.map {
case (name, '{ $x }) => '{ (${ Expr(name) }, () => $x) }
}
)

'{ $defaultsThunksExpr.toMap }
catch // TODO drop after https://github.com/lampepfl/dotty/issues/19732 (after bump to 3.3.4)
case cce: ClassCastException =>
'{
Map.empty[String, () => Any]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -502,3 +502,66 @@ class DecoderSuite extends munit.FunSuite:
assert(error.msg.contains("Could't construct int from null (tag:yaml.org,2002:null)"))
case Right(data) => fail(s"expected failure, but got: $data")
}

test("default parameters for case classes can be used when decoding") {
case class Foo(a: Int = 1, b: String = "test", c: Option[Int] = None, d: Double)
derives YamlCodec

val yaml = """d: 1.0""".stripMargin

yaml.as[Foo] match
case Left(error: YamlError) =>
fail(s"failed with YamlError: $error")
case Right(foo) =>
assertEquals(foo.a, 1)
assertEquals(foo.b, "test")
assertEquals(foo.c, None)
assertEquals(foo.d, 1.0)
}

test("default parameters for case classes are evaluated lazily") {
var times = 0
def createB = {
times += 1
s"test-${times}"
}
case class Foo(a: Int, b: String = createB) derives YamlCodec

val yaml = """a: 1""".stripMargin

yaml.as[Foo] match
case Left(error: YamlError) =>
fail(s"failed with YamlError: $error")
case Right(foo) =>
assertEquals(foo.a, 1)
assertEquals(foo.b, "test-1")

yaml.as[Foo] // skip test-2

yaml.as[Foo] match
case Left(error: YamlError) =>
fail(s"failed with YamlError: $error")
case Right(foo) =>
assertEquals(foo.a, 1)
assertEquals(foo.b, "test-3")
}

test("default parameters are not evaluated when they are provided in yaml") {
var evaluated = false
def createB = {
evaluated = true
"default"
}
case class Foo(a: Int, b: String = createB) derives YamlCodec

val yaml = """a: 1
|b: from yaml""".stripMargin

yaml.as[Foo] match
case Left(error: YamlError) =>
fail(s"failed with YamlError: $error")
case Right(foo) =>
assertEquals(foo.a, 1)
assertEquals(foo.b, "from yaml")
assert(!evaluated)
}

0 comments on commit 5bf61d1

Please sign in to comment.