diff --git a/modules/parser-gen/src/main/scala/playground/parsergen/IR.scala b/modules/parser-gen/src/main/scala/playground/parsergen/IR.scala index c3c1876e..b2105c18 100644 --- a/modules/parser-gen/src/main/scala/playground/parsergen/IR.scala +++ b/modules/parser-gen/src/main/scala/playground/parsergen/IR.scala @@ -1,14 +1,18 @@ package playground.parsergen import cats.data.NonEmptyList +import treesittersmithy.FieldName import treesittersmithy.NodeType import treesittersmithy.TypeName enum Type { - case ADT(name: TypeName, subtypes: NonEmptyList[Subtype]) + case Product(name: TypeName, fields: List[Field], children: Option[Children]) } +case class Field(name: FieldName, targetTypes: NonEmptyList[TypeName], repeated: Boolean) +case class Children(targetTypes: NonEmptyList[TypeName], repeated: Boolean) + case class Subtype(name: TypeName) object IR { @@ -16,11 +20,31 @@ object IR { def from(nt: NodeType): Type = if nt.subtypes.nonEmpty then fromADT(nt) else - sys.error("todo") + fromProduct(nt) private def fromADT(nt: NodeType): Type.ADT = Type.ADT( name = nt.tpe, subtypes = NonEmptyList.fromListUnsafe(nt.subtypes.map(subtype => Subtype(name = subtype.tpe))), ) + private def fromProduct(nt: NodeType): Type.Product = Type.Product( + name = nt.tpe, + fields = + nt.fields + .map { (fieldName, fieldInfo) => + Field( + name = fieldName, + targetTypes = NonEmptyList.fromListUnsafe(fieldInfo.types.map(_.tpe)), + repeated = fieldInfo.multiple, + ) + } + .toList, + children = nt.children.map { children => + Children( + targetTypes = NonEmptyList.fromListUnsafe(children.types.map(_.tpe)), + repeated = children.multiple, + ) + }, + ) + } diff --git a/modules/parser-gen/src/main/scala/playground/parsergen/ParserGen.scala b/modules/parser-gen/src/main/scala/playground/parsergen/ParserGen.scala index 962cbfd1..bf9140e4 100644 --- a/modules/parser-gen/src/main/scala/playground/parsergen/ParserGen.scala +++ b/modules/parser-gen/src/main/scala/playground/parsergen/ParserGen.scala @@ -1,5 +1,6 @@ package playground.parsergen +import cats.data.NonEmptyList import cats.syntax.all.* import monocle.syntax.all.* import org.polyvariant.treesitter4s.Node @@ -32,9 +33,10 @@ extension (fn: FieldName) { extension (tpe: NodeType) { def render: String = - if tpe.subtypes.nonEmpty then renderAdt(IR.from(tpe).asInstanceOf[Type.ADT]) - else - renderClass(tpe) + IR.from(tpe) match { + case adt: Type.ADT => renderAdt(adt) + case product: Type.Product => renderProduct(product) + } } @@ -78,98 +80,88 @@ private def renderAdt(adt: Type.ADT): String = { |""".stripMargin } -private def renderClass(tpe: NodeType): String = { - val name = tpe.tpe.render +private def renderProduct(p: Type.Product): String = { + val name = p.name.render - val fieldGetters = tpe - .fields - .toList - .map { (k, fieldType) => - val typeUnion = fieldType - .types - .map(tpe => show"${tpe.tpe.render}") - .reduceLeftOption(_ + " | " + _) - .getOrElse(sys.error(s"unexpected empty list of types: $k (in ${tpe.tpe})")) - - val fieldTypeAnnotation = typeUnion.pipe { - case s if fieldType.multiple => show"List[$s]" - case s => show"Option[$s]" - } + def renderTypeUnion(types: NonEmptyList[TypeName]) = types + .map(_.render) + .reduceLeft(_ + " | " + _) - val allFields = show"""node.fields.getOrElse(${k.value.literal}, Nil)""" + def renderFieldType(field: Field): String = renderTypeUnion(field.targetTypes).pipe { + case s if field.repeated => show"List[$s]" + case s => show"Option[$s]" + } + + def renderChildrenType(children: Children): String = renderTypeUnion(children.targetTypes).pipe { + case s if children.repeated => show"List[$s]" + case s => show"Option[$s]" + } - val cases = fieldType.types.map { typeInfo => - show"""case ${typeInfo.tpe.render}(node) => node""" + def renderChildType(tpe: TypeName, repeated: Boolean): String = tpe.render.pipe { + case s if repeated => show"List[$s]" + case s => show"Option[$s]" + } + + val fieldGetters = p + .fields + .map { field => + val allFields = show"""node.fields.getOrElse(${field.name.value.literal}, Nil)""" + + val cases = field.targetTypes.map { tpe => + show"""case ${tpe.render}(node) => node""" } + val fieldValue = - if fieldType.multiple then show"""$allFields.toList.collect { - |${cases.mkString("\n").indentTrim(2)} - |}""".stripMargin + if field.repeated then show"""$allFields.toList.collect { + |${cases.mkString_("\n").indentTrim(2)} + |}""".stripMargin else show"""$allFields.headOption.map { - |${cases.mkString("\n").indentTrim(2)} + |${cases.mkString_("\n").indentTrim(2)} |}""".stripMargin - show"""def ${k.render}: ${fieldTypeAnnotation} = $fieldValue""" + show"""def ${field.name.render}: ${renderFieldType(field)} = $fieldValue""" } - val typedChildren = tpe.children.map { fieldType => - val typeUnion = fieldType - .types - .map(tpe => show"${tpe.tpe.render}") - .reduceLeftOption(_ + " | " + _) - .getOrElse(sys.error(s"unexpected empty list of types in children: (in ${tpe.tpe})")) - - val fieldTypeAnnotation = typeUnion.pipe { - case s if fieldType.multiple => show"List[$s]" - case s => show"Option[$s]" - } + val typedChildren = p.children.map { children => + val fieldTypeAnnotation = renderChildrenType(children) val allChildren = show"""node.children""" - val cases = fieldType.types.map { typeInfo => - show"""case ${typeInfo.tpe.render}(node) => node""" + val cases = children.targetTypes.map { tpe => + show"""case ${tpe.render}(node) => node""" } val fieldValue = - if fieldType.multiple then show"""$allChildren.toList.collect { - |${cases.mkString("\n").indentTrim(2)} - |}""".stripMargin + if children.repeated then show"""$allChildren.toList.collect { + |${cases.mkString_("\n").indentTrim(2)} + |}""".stripMargin else show"""$allChildren.collectFirst { - |${cases.mkString("\n").indentTrim(2)} + |${cases.mkString_("\n").indentTrim(2)} |}""".stripMargin show"""def typedChildren: ${fieldTypeAnnotation} = $fieldValue""" } - val typedChildrenPrecise = tpe + val typedChildrenPrecise = p .children .toList .flatMap { fieldInfo => - fieldInfo.types.map((fieldInfo.multiple, _)) + fieldInfo.targetTypes.map((fieldInfo.repeated, _)).toList } - .map { (multiple, fieldType) => - val fieldTypeAnnotation = fieldType.tpe.render.pipe { - case s if multiple => show"List[$s]" - case s => show"Option[$s]" - } - + .map { (repeated, fieldType) => + val fieldTypeAnnotation = renderChildType(fieldType, repeated) val childValue = - if multiple then show"""node.children.toList.collect { - | case ${fieldType - .tpe - .render}(node) => node + if repeated then show"""node.children.toList.collect { + | case ${fieldType.render}(node) => node |}""".stripMargin else show"""node.children.collectFirst { - | case ${fieldType.tpe.render}(node) => node + | case ${fieldType.render}(node) => node |}""".stripMargin - show"""def ${fieldType - .tpe - .asChildName - .render}: $fieldTypeAnnotation = $childValue""".stripMargin + show"""def ${fieldType.asChildName.render}: $fieldTypeAnnotation = $childValue""".stripMargin } val methods = @@ -196,9 +188,9 @@ private def renderClass(tpe: NodeType): String = { |${methods.indentTrim(2)} | | def apply(node: Node): Either[String, $name] = - | if node.tpe == ${tpe.tpe.value.literal} + | if node.tpe == ${p.name.value.literal} | then Right(node) - | else Left(s"Expected ${tpe.tpe.render}, got $${node.tpe}") + | else Left(s"Expected ${p.name.render}, got $${node.tpe}") | def unsafeApply(node: Node): $name = apply(node).fold(sys.error, identity) | def unapply(node: Node): Option[$name] = apply(node).toOption |}