Skip to content

Commit

Permalink
initial work on directives
Browse files Browse the repository at this point in the history
  • Loading branch information
ValdemarGr committed Feb 5, 2024
1 parent 630a19e commit 4f4b41c
Show file tree
Hide file tree
Showing 9 changed files with 117 additions and 36 deletions.
22 changes: 19 additions & 3 deletions modules/core/src/main/scala/gql/Directive.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,22 @@ import gql.preparation.MergedFieldInfo
import gql.parser.QueryAst

/** A [[Directive]] takes an argument A and performs some context specific ast transformation.
*
* The [[Directive]] structure defines executable directives that modify execution https://spec.graphql.org/draft/#ExecutableDirectiveLocation.
*/
final case class Directive[A](
name: String,
isRepeatable: Boolean,
arg: EmptyableArg[A] = EmptyableArg.Empty
)
) {
def repeatable = copy(isRepeatable = true)
def unrepeatable = copy(isRepeatable = false)
}

/** Consider taking a look at the skip and include directives as an example.
*/
object Directive {
val skipDirective = Directive("skip", EmptyableArg.Lift(gql.dsl.input.arg[Boolean]("if")))
val skipDirective = Directive("skip", false, EmptyableArg.Lift(gql.dsl.input.arg[Boolean]("if")))

def skipPositions[F[_]]: List[Position[F, ?]] = {
val field = Position.Field(
Expand Down Expand Up @@ -61,7 +67,7 @@ object Directive {
List(field, fragmentSpread, inlineFragmentSpread)
}

val includeDirective = Directive("include", EmptyableArg.Lift(gql.dsl.input.arg[Boolean]("if")))
val includeDirective = Directive("include", false, EmptyableArg.Lift(gql.dsl.input.arg[Boolean]("if")))

def includePositions[F[_]]: List[Position[F, ?]] = {
val field = Position.Field(
Expand Down Expand Up @@ -122,4 +128,14 @@ object Position {
directive: Directive[A],
handler: QueryHandler[QA.InlineFragment, A]
) extends Position[Nothing, A]

final case class Enum[A](
directive: Directive[A],
handler: Enum.Handler[A]
) extends Position[Nothing, A]
object Enum {
trait Handler[A] {
def apply[B](a: A, e: ast.Enum[B]): Either[String, ast.Enum[B]]
}
}
}
5 changes: 3 additions & 2 deletions modules/core/src/main/scala/gql/SchemaShape.scala
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ object SchemaShape {
case ii: TypeInfo.InInfo =>
ii.t match {
case Scalar(_, _, _, _) => __TypeKind.SCALAR
case Enum(_, _, _) => __TypeKind.ENUM
case Enum(_, _, _, _) => __TypeKind.ENUM
case _: Input[?] => __TypeKind.INPUT_OBJECT
}
},
Expand Down Expand Up @@ -646,7 +646,7 @@ object SchemaShape {
case _ => None
},
"enumValues" -> lift(inclDeprecated) { case (_, ti) =>
ti.asToplevel.collect { case Enum(_, m, _) => m.toList.map { case (k, v) => NamedEnumValue(k, v) } }
ti.asToplevel.collect { case Enum(_, m, _, _) => m.toList.map { case (k, v) => NamedEnumValue(k, v) } }
},
"inputFields" -> lift(inclDeprecated) {
case (_, ii: TypeInfo.InInfo) =>
Expand Down Expand Up @@ -705,6 +705,7 @@ object SchemaShape {
case Position.Field(_, _) => DirectiveLocation.FIELD
case Position.FragmentSpread(_, _) => DirectiveLocation.FRAGMENT_SPREAD
case Position.InlineFragmentSpread(_, _) => DirectiveLocation.INLINE_FRAGMENT
case Position.Enum(_, _) => DirectiveLocation.ENUM
}
zs
},
Expand Down
4 changes: 2 additions & 2 deletions modules/core/src/main/scala/gql/Validation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ object Validation {
validateTypeName[F, G](name) *>
validateArg[F, G](fields, discovery)
}
case Enum(name, _, _) => validateTypeName[F, G](name)
case Enum(name, _, _, _) => validateTypeName[F, G](name)
case Scalar(name, _, _, _) => validateTypeName[F, G](name)
}

Expand Down Expand Up @@ -480,7 +480,7 @@ object Validation {
G: Monad[G]
): G[Unit] =
tl match {
case Enum(_, _, _) => G.unit
case Enum(_, _, _, _) => G.unit
case Scalar(_, _, _, _) => G.unit
case s: Selectable[F, ?] => validateToplevel[F, G](s, discovery)
case OutArr(of, _, _) => validateOutput[F, G](of, discovery)
Expand Down
8 changes: 8 additions & 0 deletions modules/core/src/main/scala/gql/ast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ object ast extends AstImplicits.Implicits {
final case class Enum[A](
name: String,
mappings: NonEmptyList[(String, EnumValue[? <: A])],
directives: List[SchemaDirective[Nothing, Position.Enum]] = Nil,
description: Option[String] = None
) extends OutToplevel[fs2.Pure, A]
with InToplevel[A] {
Expand Down Expand Up @@ -287,6 +288,13 @@ object ast extends AstImplicits.Implicits {
trait IDLowPrio {
implicit def idIn[A](implicit s: Scalar[A]): In[ID[A]] = ID.idTpe[A]
}


// TypeSystemDirectiveLocation
final case class SchemaDirective[+F[_], P[x] <: Position[F, x]](
position: P[?],
args: List[(String, V[Const, Unit])]
)
}

object AstImplicits {
Expand Down
8 changes: 4 additions & 4 deletions modules/core/src/main/scala/gql/dsl/DirectiveDsl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ import gql._

trait DirectiveDsl[F[_]] {
def directive(name: String): Directive[Unit] =
Directive(name)
Directive(name, isRepeatable = false)

def directive[A](name: String, arg: Arg[A]): Directive[A] =
Directive(name, EmptyableArg.Lift(arg))
Directive(name, isRepeatable = false, EmptyableArg.Lift(arg))

def onField[A](directive: Directive[A], handler: Position.FieldHandler[F, A]): State[SchemaState[F], Position.Field[F, A]] =
DirectiveDsl.onField(directive, handler)
Expand All @@ -47,10 +47,10 @@ trait DirectiveDslFull {
State(s => (s.copy(positions = pos :: s.positions), pos))

def directive(name: String): Directive[Unit] =
Directive(name)
Directive(name, isRepeatable= false)

def directive[A](name: String, arg: Arg[A]): Directive[A] =
Directive(name, EmptyableArg.Lift(arg))
Directive(name, isRepeatable= false, EmptyableArg.Lift(arg))

def onField[F[_], A](directive: Directive[A], handler: Position.FieldHandler[F, A]): State[SchemaState[F], Position.Field[F, A]] =
addPosition[F, A, Position.Field[F, A]](Position.Field(directive, handler))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class ArgParsing[C](variables: VariableMap[C]) {
verifiedF *> verifiedTypenameF *> parseInnerF
}
}
case (e @ Enum(name, _, _), v) =>
case (e @ Enum(name, _, _, _), v) =>
val fa: G[(String, List[C])] = v match {
case V.EnumValue(s, cs) => G.pure((s, cs))
case V.StringValue(s, cs) if ambigiousEnum => G.pure((s, cs))
Expand Down
62 changes: 53 additions & 9 deletions modules/core/src/main/scala/gql/preparation/DirectiveAlg.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import gql.parser._
import cats._
import cats.implicits._
import gql._
import scala.reflect.ClassTag

class DirectiveAlg[F[_], C](
positions: Map[String, List[Position[F, ?]]],
Expand All @@ -27,19 +28,62 @@ class DirectiveAlg[F[_], C](
type G[A] = Alg[C, A]
val G = Alg.Ops[C]

def parseArg[P[x] <: Position[F, x], A](p: P[A], args: Option[QueryAst.Arguments[C, AnyValue]], context: List[C]): G[A] = {
def parseArg[P[x] <: Position[F, x], A](p: P[A], args: List[(String, Value[AnyValue, List[C]])], context: List[C]): G[A] = {
p.directive.arg match {
case EmptyableArg.Empty =>
args match {
case Some(_) => G.raise(s"Directive '${p.directive.name}' does not expect arguments", context)
case None => G.unit
}
case EmptyableArg.Lift(a) =>
val argFields = args.toList.flatMap(_.nel.toList).map(a => a.name -> a.value.map(List(_))).toMap
ap.decodeArg(a, argFields, ambigiousEnum = false, context)
case EmptyableArg.Lift(a) => ap.decodeArg(a, args.toMap, ambigiousEnum = false, context)
case EmptyableArg.Empty if args.isEmpty => G.unit
case EmptyableArg.Empty => G.raise(s"Directive '${p.directive.name}' does not expect arguments", context)
}
}

def getDirective[P[x] <: Position[F, x]](name: String, context: List[C])(pf: PartialFunction[Position[F, ?], P[?]]): Alg[C, P[?]] =
positions.get(name) match {
case None => G.raise(s"Couldn't find directive '$name'", context)
case Some(d) =>
val p = d.collectFirst(pf)
G.raiseOpt(p, s"Directive '$name' cannot appear here", context)
}

case class ParsedDirective[A, P[x] <: Position[F, x]](p: P[A], a: A)
def parseProvided[P[x] <: Position[F, x]](
directives: Option[QueryAst.Directives[C, AnyValue]],
context: List[C]
)(pf: PartialFunction[Position[F, ?], P[?]]): Alg[C, List[ParsedDirective[?, P]]] =
directives.map(_.nel.toList).getOrElse(Nil).parTraverse { d =>
getDirective[P](d.name, context)(pf).flatMap { p =>
// rigid type variable inference help
def go[A](p: P[A]): G[ParsedDirective[A,P]] =
parseArg(
p,
d.arguments.map(_.nel.toList).getOrElse(Nil).map(a => a.name -> a.value.map(List(_))),
context
).map(ParsedDirective(p, _))

go(p)
}
}

def parseProvidedSubtype[P[x] <: Position[F, x]](
directives: Option[QueryAst.Directives[C, AnyValue]],
context: List[C]
)(implicit CT: ClassTag[P[Any]]): Alg[C, List[ParsedDirective[?, P]]] = {
val pf: PartialFunction[Position[F, ?], P[?]] = PartialFunction
.fromFunction(identity[Position[F, ?]])
.andThen(x => CT.unapply(x))
.andThen { case Some(x) => x: P[?] }
parseProvided[P](directives, context)(pf)
}

def parseSchemaDirective[P[x] <: Position[F, x]](sd: ast.SchemaDirective[F, P], context: List[C]): G[ParsedDirective[?, P]] = {
// rigid type variable inference help
def go[A](p: P[A]): G[ParsedDirective[A, P]] =
parseArg(p, sd.args.map{ case (k, v) => k -> v.map(_ => List.empty[C])}, context).flatMap{ a =>
G.pure(ParsedDirective(p, a))
}

go(sd.position: P[?]).widen[ParsedDirective[?, P]]
}

def foldDirectives[P[x] <: Position[F, x]]: DirectiveAlg.PartiallyAppliedFold[F, C, P] =
new DirectiveAlg.PartiallyAppliedFold[F, C, P] {
override def apply[H[_]: Traverse, A](directives: Option[QueryAst.Directives[C, AnyValue]], context: List[C])(base: A)(
Expand Down
19 changes: 12 additions & 7 deletions modules/core/src/main/scala/gql/preparation/FieldCollection.scala
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,12 @@ class FieldCollection[F[_], C](
all
.collect { case QA.Selection.InlineFragmentSelection(f, c) => (c, f) }
.parFlatTraverse { case (caret, f) =>
da.foldDirectives[Position.InlineFragmentSpread](f.directives, List(caret))(f) {
case (f, p: Position.InlineFragmentSpread[a], d) =>
da.parseArg(p, d.arguments, List(caret)).map(p.handler(_, f)).flatMap(G.raiseEither(_, List(caret)))
}.map(_ tupleLeft caret)
da
.parseProvidedSubtype[Position.InlineFragmentSpread](f.directives, List(caret))
.flatMap(_.foldLeftM(List(f)) { case (fs, d) =>
fs.parFlatTraverse(f => G.raiseEither(d.p.handler(d.a, f), List(caret)))
})
.map(_.tupleLeft(caret))
}
.flatMap(_.parFlatTraverse { case (caret, f) =>
f.typeCondition.traverse(matchType(_, sel, caret)).map(_.getOrElse(sel)).flatMap { t =>
Expand All @@ -139,9 +141,12 @@ class FieldCollection[F[_], C](
val realFragments = all
.collect { case QA.Selection.FragmentSpreadSelection(f, c) => (c, f) }
.parFlatTraverse { case (caret, f) =>
da.foldDirectives[Position.FragmentSpread](f.directives, List(caret))(f) { case (f, p: Position.FragmentSpread[a], d) =>
da.parseArg(p, d.arguments, List(caret)).map(p.handler(_, f)).flatMap(G.raiseEither(_, List(caret)))
}.map(_ tupleLeft caret)
da
.parseProvidedSubtype[Position.FragmentSpread](f.directives, List(caret))
.flatMap(_.foldLeftM(List(f)) { case (fs, d) =>
fs.parFlatTraverse(f => G.raiseEither(d.p.handler(d.a, f), List(caret)))
})
.map(_.tupleLeft(caret))
}
.flatMap(_.parFlatTraverse { case (caret, f) =>
val fn = f.fragmentName
Expand Down
23 changes: 15 additions & 8 deletions modules/core/src/main/scala/gql/preparation/QueryPreparation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,15 @@ class QueryPreparation[F[_], C](
case (s: Selectable[F, a], Some(ss)) =>
liftK(prepareSelectable[A](s, ss).widen[Prepared[F, A]])
case (e: Enum[a], None) =>
liftK(nextNodeId).map(PreparedLeaf(_, e.name, x => Json.fromString(e.revm(x))))
val cs = List(fi.caret)
liftK {
e.directives
.parTraverse(da.parseSchemaDirective(_, cs))
.flatMap(_.foldLeftM(e){ case (e, d) => G.raiseEither(d.p.handler(d.a, e), cs) })
.flatMap{ e =>
nextNodeId.map(PreparedLeaf[F, a](_, e.name, x => Json.fromString(e.revm(x))))
}
}
case (s: Scalar[a], None) =>
import io.circe.syntax._
liftK(nextNodeId).map(PreparedLeaf(_, s.name, x => s.encoder(x).asJson))
Expand All @@ -156,13 +164,12 @@ class QueryPreparation[F[_], C](
currentTypename: String
): G[List[PreparedDataField[F, I, ?]]] = {
da
.foldDirectives[Position.Field[F, *]][List, (Field[F, I, ?], MergedFieldInfo[F, C])](fi.directives, List(fi.caret))(
(field, fi)
) { case ((f: Field[F, I, ?], fi), p: Position.Field[F, a], d) =>
da.parseArg(p, d.arguments, List(fi.caret))
.map(p.handler(_, f, fi))
.flatMap(G.raiseEither(_, List(fi.caret)))
}
.parseProvidedSubtype[Position.Field[F, *]](fi.directives, List(fi.caret))
.flatMap(_.foldLeftM[G, List[(Field[F, I, ?], MergedFieldInfo[F, C])]](List((field, fi))) { case (xs, prov) =>
xs.parFlatTraverse { case (f: Field[F, I, ?], fi) =>
G.raiseEither(prov.p.handler(prov.a, f, fi), List(fi.caret))
}
})
.flatMap(_.parTraverse { case (field: Field[F, I, o2], fi) =>
val rootUniqueName = UniqueEdgeCursor(s"${currentTypename}_${fi.name}")

Expand Down

0 comments on commit 4f4b41c

Please sign in to comment.