Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#203: Wrap POST body as in wiro (closes #203) #212

Merged
merged 14 commits into from
Apr 22, 2020
22 changes: 8 additions & 14 deletions tapiro/core/src/main/scala/io/buildo/tapiro/AkkaHttpMeta.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ object AkkaHttpMeta {
q"""
package ${`package`} {
..${imports.toList.map(i => q"import $i._")}
import akka.http.scaladsl.server._
import akka.http.scaladsl.server.Directives._
import io.circe.{ Decoder, Encoder }
import sttp.tapir.server.akkahttp._
import sttp.tapir.Codec.{ JsonCodec, PlainCodec }
import sttp.model.StatusCode
import akka.http.scaladsl.server._
import akka.http.scaladsl.server.Directives._

object $httpEndpointsName {
def routes[AuthToken](controller: $controllerName[AuthToken], statusCodes: String => StatusCode = _ => StatusCode.UnprocessableEntity)(..$implicits): Route = {
Expand All @@ -40,17 +41,10 @@ object AkkaHttpMeta {
q"pathPrefix($pathName) { List(..$rest).foldLeft[Route]($first)(_ ~ _) }"
}

val endpoints = (routes: List[Route]) =>
routes.flatMap { route =>
val name = Term.Name(route.name.last)
val endpointsName = Term.Select(Term.Name("endpoints"), name)
val controllersName = Term.Select(Term.Name("controller"), name)
val controllerContent =
if (route.params.length <= 1) Some(controllersName)
else Some(Term.Select(Term.Eta(controllersName), Term.Name("tupled")))
controllerContent.map { content =>
val toRoute = Term.Apply(Term.Select(endpointsName, Term.Name("toRoute")), List(content))
q"val ${Pat.Var(name)} = $toRoute"
}
val endpoints = (routes: List[TapiroRoute]) =>
routes.map { route =>
val name = Term.Name(route.route.name.last)
val endpointImpl = Meta.toEndpointImplementation(route)
q"val ${Pat.Var(name)} = endpoints.$name.toRoute($endpointImpl)"
}
}
20 changes: 7 additions & 13 deletions tapiro/core/src/main/scala/io/buildo/tapiro/Http4sMeta.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ object Http4sMeta {
import cats.effect._
import cats.implicits._
import cats.data.NonEmptyList
import io.circe.{ Decoder, Encoder }
import org.http4s._
import org.http4s.server.Router
import sttp.tapir.server.http4s._
Expand All @@ -41,20 +42,13 @@ object Http4sMeta {
val first = Term.Name(head.name.last)
val rest = tail.map(a => Term.Name(a.name.last))
val route: Lit.String = Lit.String("/" + pathName.value)
q"Router($route -> NonEmptyList($first, List(..$rest)).reduceK)"
q"Router($route -> NonEmptyList.of($first, ..$rest).reduceK)"
}

val endpoints = (routes: List[Route]) =>
routes.flatMap { route =>
val name = Term.Name(route.name.last)
val endpointsName = Term.Select(Term.Name("endpoints"), name)
val controllersName = Term.Select(Term.Name("controller"), name)
val controllerContent =
if (route.params.length <= 1) Some(controllersName)
else Some(Term.Select(Term.Eta(controllersName), Term.Name("tupled")))
controllerContent.map { content =>
val toRoutes = Term.Apply(Term.Select(endpointsName, Term.Name("toRoutes")), List(content))
q"val ${Pat.Var(name)} = $toRoutes"
}
val endpoints = (routes: List[TapiroRoute]) =>
routes.map { route =>
val name = Term.Name(route.route.name.last)
val endpointImpl = Meta.toEndpointImplementation(route)
q"val ${Pat.Var(name)} = endpoints.$name.toRoutes($endpointImpl)"
}
}
88 changes: 58 additions & 30 deletions tapiro/core/src/main/scala/io/buildo/tapiro/Meta.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,49 @@ package io.buildo.tapiro
import io.buildo.metarpheus.core.intermediate.{TaggedUnion, Type => MetarpheusType}

import scala.meta._
import scala.meta.contrib._

import cats.data.NonEmptyList

object Meta {
val codecsImplicits = (routes: List[TapiroRoute]) => {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is getting quite messy...

Part of the problem is that we're delaying the conversion to scalameta type for each type of codec because we want to call .distinct while we still have a list of metarpheus types, because it wouldn't work on scalameta types because == is by reference

Maybe we could use structural equality from scala.meta.contrib (https://scalameta.org/docs/trees/guide.html#compare-trees-for-equality) to clean this up a bit?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think using isEqual is a good idea

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've tried to rewrite that function and I think it reads a bit better now

val jsonCodecs = (routes.flatMap {
case TapiroRoute(route, error) =>
val params: List[MetarpheusType] = route.params.map(_.tpe)
((if (route.method == "post") params else Nil) ++
(error match {
case TapiroRouteError.OtherError(t) => List(t)
case _ => Nil
}) :+
route.returns)
}.distinct
.filter(t => typeNameString(t) != "Unit") //no json codec for Unit in tapir
.filter(t => typeNameString(t) != "String")
.filter(t => typeNameString(t) != "AuthToken")
.map(toScalametaType)
++ taggedUnionErrorMembers(routes))
.map(t => t"JsonCodec[$t]")
val plainCodecs = routes.flatMap {
case TapiroRoute(route, _) =>
(if (route.method == "get") route.params.map(_.tpe) else Nil) ++
route.params.map(_.tpe).filter(typeNameString(_) == "AuthToken")
}.distinct.map(t => t"PlainCodec[${toScalametaType(t)}]")
val codecs = jsonCodecs ++ plainCodecs
codecs.zipWithIndex.map(toImplicitParam.tupled)
val notUnit = (t: MetarpheusType) => t != MetarpheusType.Name("Unit")
val toDecoder = (t: Type) => t"Decoder[$t]"
val toEncoder = (t: Type) => t"Encoder[$t]"
val toJsonCodec = (t: Type) => t"JsonCodec[$t]"
val toPlainCodec = (t: Type) => t"PlainCodec[$t]"
val routeRequiredImplicits = (route: TapiroRoute) => {
val (authParamTypes, nonAuthParamTypes) =
route.route.params.map(_.tpe).partition(isAuthToken)
val inputImplicits =
route.method match {
case RouteMethod.GET =>
nonAuthParamTypes.map(toScalametaType).map(toPlainCodec)
case RouteMethod.POST =>
nonAuthParamTypes.map(toScalametaType).flatMap(t => List(toDecoder(t), toEncoder(t)))
}
val outputImplicits =
List(route.route.returns).filter(notUnit).map(toScalametaType).map(toJsonCodec)
val errorImplicits =
route.error match {
case RouteError.TaggedUnionError(tu) =>
tu.values.map(taggedUnionMemberType(tu)).map(toJsonCodec)
case RouteError.OtherError(t) =>
List(t).filter(notUnit).map(toScalametaType).map(toJsonCodec)
}
val authImplicits = authParamTypes.map(toScalametaType).map(toPlainCodec)
inputImplicits ++ outputImplicits ++ errorImplicits ++ authImplicits
}
deduplicate(routes.flatMap(routeRequiredImplicits)).zipWithIndex.map(toImplicitParam.tupled)
}

private[this] val taggedUnionErrorMembers = (routes: List[TapiroRoute]) => {
val taggedUnions = routes.collect {
case TapiroRoute(_, TapiroRouteError.TaggedUnionError(tu)) => tu
}.distinct
taggedUnions.flatMap { taggedUnion =>
taggedUnion.values.map(taggedUnionMemberType(taggedUnion))
private[this] val deduplicate: List[Type] => List[Type] = (ts: List[Type]) =>
ts match {
case Nil => Nil
case head :: tail => head :: deduplicate(tail.filter(!_.isEqual(head)))
}
}

private[this] val isAuthToken = (t: MetarpheusType) => t == MetarpheusType.Name("AuthToken")

private[this] val toImplicitParam = (paramType: Type, index: Int) => {
val paramName = Term.Name(s"codec$index")
Expand Down Expand Up @@ -71,4 +76,27 @@ object Meta {
def packageFromList(`package`: NonEmptyList[String]): Term.Ref =
`package`.tail
.foldLeft[Term.Ref](Term.Name(`package`.head))((acc, n) => Term.Select(acc, Term.Name(n)))

val toEndpointImplementation = (route: TapiroRoute) => {
val name = Term.Name(route.route.name.last)
val controllersName = q"controller.$name"
route.method match {
case RouteMethod.GET =>
route.route.params.length match {
case 0 => q"_ => $controllersName()"
case 1 => controllersName
case _ => q"($controllersName _).tupled"
}
case RouteMethod.POST =>
val fields = route.route.params
.filterNot(_.tpe == MetarpheusType.Name("AuthToken"))
.map(p => Term.Name(p.name.getOrElse(Meta.typeNameString(p.tpe))))
val hasAuth = route.route.params
.exists(_.tpe == MetarpheusType.Name("AuthToken"))
if (hasAuth)
q"{ case (x, token) => $controllersName(..${fields.map(f => q"x.$f")}, token) }"
else
q"x => $controllersName(..${fields.map(f => q"x.$f")})"
}
}
}
19 changes: 15 additions & 4 deletions tapiro/core/src/main/scala/io/buildo/tapiro/MetarpheusHelper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,18 @@ package io.buildo.tapiro
import io.buildo.metarpheus.core.intermediate.{Type => MetarpheusType, Model, TaggedUnion, Route}

object MetarpheusHelper {
def routeError(route: Route, models: List[Model]): TapiroRouteError =
def toTapiroRoute(models: List[Model])(route: Route): TapiroRoute =
TapiroRoute(
route = route,
method = route.method match {
case "get" => RouteMethod.GET
case "post" => RouteMethod.POST
case _ => throw new Exception("method not supported")
},
error = routeError(route, models),
)

def routeError(route: Route, models: List[Model]): RouteError =
route.error.map { error =>
val errorName = error match {
case MetarpheusType.Name(name) => name
Expand All @@ -16,7 +27,7 @@ object MetarpheusHelper {
if (candidates.length > 1) throw new Exception(s"ambiguous error type name $errorName")
else
candidates.headOption
.map(TapiroRouteError.TaggedUnionError.apply)
.getOrElse(TapiroRouteError.OtherError(error))
}.getOrElse(TapiroRouteError.OtherError(MetarpheusType.Name("String")))
.map(RouteError.TaggedUnionError.apply)
.getOrElse(RouteError.OtherError(error))
}.getOrElse(RouteError.OtherError(MetarpheusType.Name("String")))
}
Loading