Skip to content

Commit

Permalink
Merge pull request #212 from buildo/203-wrap_post_body_as
Browse files Browse the repository at this point in the history
#203: Wrap POST body as in wiro (closes #203)
  • Loading branch information
tpetrucciani authored Apr 22, 2020
2 parents 79fc705 + 80eec9a commit 37af865
Show file tree
Hide file tree
Showing 8 changed files with 358 additions and 134 deletions.
22 changes: 8 additions & 14 deletions tapiro/core/src/main/scala/io/buildo/tapiro/AkkaHttpMeta.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ object AkkaHttpMeta {
q"""
package ${`package`} {
..${imports.toList.map(i => q"import $i._")}
import akka.http.scaladsl.server._
import akka.http.scaladsl.server.Directives._
import io.circe.{ Decoder, Encoder }
import sttp.tapir.server.akkahttp._
import sttp.tapir.Codec.{ JsonCodec, PlainCodec }
import sttp.model.StatusCode
import akka.http.scaladsl.server._
import akka.http.scaladsl.server.Directives._

object $httpEndpointsName {
def routes[AuthToken](controller: $controllerName[AuthToken], statusCodes: String => StatusCode = _ => StatusCode.UnprocessableEntity)(..$implicits): Route = {
Expand All @@ -40,17 +41,10 @@ object AkkaHttpMeta {
q"pathPrefix($pathName) { List(..$rest).foldLeft[Route]($first)(_ ~ _) }"
}

val endpoints = (routes: List[Route]) =>
routes.flatMap { route =>
val name = Term.Name(route.name.last)
val endpointsName = Term.Select(Term.Name("endpoints"), name)
val controllersName = Term.Select(Term.Name("controller"), name)
val controllerContent =
if (route.params.length <= 1) Some(controllersName)
else Some(Term.Select(Term.Eta(controllersName), Term.Name("tupled")))
controllerContent.map { content =>
val toRoute = Term.Apply(Term.Select(endpointsName, Term.Name("toRoute")), List(content))
q"val ${Pat.Var(name)} = $toRoute"
}
val endpoints = (routes: List[TapiroRoute]) =>
routes.map { route =>
val name = Term.Name(route.route.name.last)
val endpointImpl = Meta.toEndpointImplementation(route)
q"val ${Pat.Var(name)} = endpoints.$name.toRoute($endpointImpl)"
}
}
20 changes: 7 additions & 13 deletions tapiro/core/src/main/scala/io/buildo/tapiro/Http4sMeta.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ object Http4sMeta {
import cats.effect._
import cats.implicits._
import cats.data.NonEmptyList
import io.circe.{ Decoder, Encoder }
import org.http4s._
import org.http4s.server.Router
import sttp.tapir.server.http4s._
Expand All @@ -41,20 +42,13 @@ object Http4sMeta {
val first = Term.Name(head.name.last)
val rest = tail.map(a => Term.Name(a.name.last))
val route: Lit.String = Lit.String("/" + pathName.value)
q"Router($route -> NonEmptyList($first, List(..$rest)).reduceK)"
q"Router($route -> NonEmptyList.of($first, ..$rest).reduceK)"
}

val endpoints = (routes: List[Route]) =>
routes.flatMap { route =>
val name = Term.Name(route.name.last)
val endpointsName = Term.Select(Term.Name("endpoints"), name)
val controllersName = Term.Select(Term.Name("controller"), name)
val controllerContent =
if (route.params.length <= 1) Some(controllersName)
else Some(Term.Select(Term.Eta(controllersName), Term.Name("tupled")))
controllerContent.map { content =>
val toRoutes = Term.Apply(Term.Select(endpointsName, Term.Name("toRoutes")), List(content))
q"val ${Pat.Var(name)} = $toRoutes"
}
val endpoints = (routes: List[TapiroRoute]) =>
routes.map { route =>
val name = Term.Name(route.route.name.last)
val endpointImpl = Meta.toEndpointImplementation(route)
q"val ${Pat.Var(name)} = endpoints.$name.toRoutes($endpointImpl)"
}
}
88 changes: 58 additions & 30 deletions tapiro/core/src/main/scala/io/buildo/tapiro/Meta.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,49 @@ package io.buildo.tapiro
import io.buildo.metarpheus.core.intermediate.{TaggedUnion, Type => MetarpheusType}

import scala.meta._
import scala.meta.contrib._

import cats.data.NonEmptyList

object Meta {
val codecsImplicits = (routes: List[TapiroRoute]) => {
val jsonCodecs = (routes.flatMap {
case TapiroRoute(route, error) =>
val params: List[MetarpheusType] = route.params.map(_.tpe)
((if (route.method == "post") params else Nil) ++
(error match {
case TapiroRouteError.OtherError(t) => List(t)
case _ => Nil
}) :+
route.returns)
}.distinct
.filter(t => typeNameString(t) != "Unit") //no json codec for Unit in tapir
.filter(t => typeNameString(t) != "String")
.filter(t => typeNameString(t) != "AuthToken")
.map(toScalametaType)
++ taggedUnionErrorMembers(routes))
.map(t => t"JsonCodec[$t]")
val plainCodecs = routes.flatMap {
case TapiroRoute(route, _) =>
(if (route.method == "get") route.params.map(_.tpe) else Nil) ++
route.params.map(_.tpe).filter(typeNameString(_) == "AuthToken")
}.distinct.map(t => t"PlainCodec[${toScalametaType(t)}]")
val codecs = jsonCodecs ++ plainCodecs
codecs.zipWithIndex.map(toImplicitParam.tupled)
val notUnit = (t: MetarpheusType) => t != MetarpheusType.Name("Unit")
val toDecoder = (t: Type) => t"Decoder[$t]"
val toEncoder = (t: Type) => t"Encoder[$t]"
val toJsonCodec = (t: Type) => t"JsonCodec[$t]"
val toPlainCodec = (t: Type) => t"PlainCodec[$t]"
val routeRequiredImplicits = (route: TapiroRoute) => {
val (authParamTypes, nonAuthParamTypes) =
route.route.params.map(_.tpe).partition(isAuthToken)
val inputImplicits =
route.method match {
case RouteMethod.GET =>
nonAuthParamTypes.map(toScalametaType).map(toPlainCodec)
case RouteMethod.POST =>
nonAuthParamTypes.map(toScalametaType).flatMap(t => List(toDecoder(t), toEncoder(t)))
}
val outputImplicits =
List(route.route.returns).filter(notUnit).map(toScalametaType).map(toJsonCodec)
val errorImplicits =
route.error match {
case RouteError.TaggedUnionError(tu) =>
tu.values.map(taggedUnionMemberType(tu)).map(toJsonCodec)
case RouteError.OtherError(t) =>
List(t).filter(notUnit).map(toScalametaType).map(toJsonCodec)
}
val authImplicits = authParamTypes.map(toScalametaType).map(toPlainCodec)
inputImplicits ++ outputImplicits ++ errorImplicits ++ authImplicits
}
deduplicate(routes.flatMap(routeRequiredImplicits)).zipWithIndex.map(toImplicitParam.tupled)
}

private[this] val taggedUnionErrorMembers = (routes: List[TapiroRoute]) => {
val taggedUnions = routes.collect {
case TapiroRoute(_, TapiroRouteError.TaggedUnionError(tu)) => tu
}.distinct
taggedUnions.flatMap { taggedUnion =>
taggedUnion.values.map(taggedUnionMemberType(taggedUnion))
private[this] val deduplicate: List[Type] => List[Type] = (ts: List[Type]) =>
ts match {
case Nil => Nil
case head :: tail => head :: deduplicate(tail.filter(!_.isEqual(head)))
}
}

private[this] val isAuthToken = (t: MetarpheusType) => t == MetarpheusType.Name("AuthToken")

private[this] val toImplicitParam = (paramType: Type, index: Int) => {
val paramName = Term.Name(s"codec$index")
Expand Down Expand Up @@ -71,4 +76,27 @@ object Meta {
def packageFromList(`package`: NonEmptyList[String]): Term.Ref =
`package`.tail
.foldLeft[Term.Ref](Term.Name(`package`.head))((acc, n) => Term.Select(acc, Term.Name(n)))

val toEndpointImplementation = (route: TapiroRoute) => {
val name = Term.Name(route.route.name.last)
val controllersName = q"controller.$name"
route.method match {
case RouteMethod.GET =>
route.route.params.length match {
case 0 => q"_ => $controllersName()"
case 1 => controllersName
case _ => q"($controllersName _).tupled"
}
case RouteMethod.POST =>
val fields = route.route.params
.filterNot(_.tpe == MetarpheusType.Name("AuthToken"))
.map(p => Term.Name(p.name.getOrElse(Meta.typeNameString(p.tpe))))
val hasAuth = route.route.params
.exists(_.tpe == MetarpheusType.Name("AuthToken"))
if (hasAuth)
q"{ case (x, token) => $controllersName(..${fields.map(f => q"x.$f")}, token) }"
else
q"x => $controllersName(..${fields.map(f => q"x.$f")})"
}
}
}
19 changes: 15 additions & 4 deletions tapiro/core/src/main/scala/io/buildo/tapiro/MetarpheusHelper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,18 @@ package io.buildo.tapiro
import io.buildo.metarpheus.core.intermediate.{Type => MetarpheusType, Model, TaggedUnion, Route}

object MetarpheusHelper {
def routeError(route: Route, models: List[Model]): TapiroRouteError =
def toTapiroRoute(models: List[Model])(route: Route): TapiroRoute =
TapiroRoute(
route = route,
method = route.method match {
case "get" => RouteMethod.GET
case "post" => RouteMethod.POST
case _ => throw new Exception("method not supported")
},
error = routeError(route, models),
)

def routeError(route: Route, models: List[Model]): RouteError =
route.error.map { error =>
val errorName = error match {
case MetarpheusType.Name(name) => name
Expand All @@ -16,7 +27,7 @@ object MetarpheusHelper {
if (candidates.length > 1) throw new Exception(s"ambiguous error type name $errorName")
else
candidates.headOption
.map(TapiroRouteError.TaggedUnionError.apply)
.getOrElse(TapiroRouteError.OtherError(error))
}.getOrElse(TapiroRouteError.OtherError(MetarpheusType.Name("String")))
.map(RouteError.TaggedUnionError.apply)
.getOrElse(RouteError.OtherError(error))
}.getOrElse(RouteError.OtherError(MetarpheusType.Name("String")))
}
Loading

0 comments on commit 37af865

Please sign in to comment.