diff --git a/tapiro/core/src/main/scala/io/buildo/tapiro/AkkaHttpMeta.scala b/tapiro/core/src/main/scala/io/buildo/tapiro/AkkaHttpMeta.scala index d2e5ef92..62913377 100644 --- a/tapiro/core/src/main/scala/io/buildo/tapiro/AkkaHttpMeta.scala +++ b/tapiro/core/src/main/scala/io/buildo/tapiro/AkkaHttpMeta.scala @@ -19,11 +19,12 @@ object AkkaHttpMeta { q""" package ${`package`} { ..${imports.toList.map(i => q"import $i._")} + import akka.http.scaladsl.server._ + import akka.http.scaladsl.server.Directives._ + import io.circe.{ Decoder, Encoder } import sttp.tapir.server.akkahttp._ import sttp.tapir.Codec.{ JsonCodec, PlainCodec } import sttp.model.StatusCode - import akka.http.scaladsl.server._ - import akka.http.scaladsl.server.Directives._ object $httpEndpointsName { def routes[AuthToken](controller: $controllerName[AuthToken], statusCodes: String => StatusCode = _ => StatusCode.UnprocessableEntity)(..$implicits): Route = { @@ -40,17 +41,10 @@ object AkkaHttpMeta { q"pathPrefix($pathName) { List(..$rest).foldLeft[Route]($first)(_ ~ _) }" } - val endpoints = (routes: List[Route]) => - routes.flatMap { route => - val name = Term.Name(route.name.last) - val endpointsName = Term.Select(Term.Name("endpoints"), name) - val controllersName = Term.Select(Term.Name("controller"), name) - val controllerContent = - if (route.params.length <= 1) Some(controllersName) - else Some(Term.Select(Term.Eta(controllersName), Term.Name("tupled"))) - controllerContent.map { content => - val toRoute = Term.Apply(Term.Select(endpointsName, Term.Name("toRoute")), List(content)) - q"val ${Pat.Var(name)} = $toRoute" - } + val endpoints = (routes: List[TapiroRoute]) => + routes.map { route => + val name = Term.Name(route.route.name.last) + val endpointImpl = Meta.toEndpointImplementation(route) + q"val ${Pat.Var(name)} = endpoints.$name.toRoute($endpointImpl)" } } diff --git a/tapiro/core/src/main/scala/io/buildo/tapiro/Http4sMeta.scala b/tapiro/core/src/main/scala/io/buildo/tapiro/Http4sMeta.scala index 105b7bcf..fd7eba30 100644 --- a/tapiro/core/src/main/scala/io/buildo/tapiro/Http4sMeta.scala +++ b/tapiro/core/src/main/scala/io/buildo/tapiro/Http4sMeta.scala @@ -22,6 +22,7 @@ object Http4sMeta { import cats.effect._ import cats.implicits._ import cats.data.NonEmptyList + import io.circe.{ Decoder, Encoder } import org.http4s._ import org.http4s.server.Router import sttp.tapir.server.http4s._ @@ -41,20 +42,13 @@ object Http4sMeta { val first = Term.Name(head.name.last) val rest = tail.map(a => Term.Name(a.name.last)) val route: Lit.String = Lit.String("/" + pathName.value) - q"Router($route -> NonEmptyList($first, List(..$rest)).reduceK)" + q"Router($route -> NonEmptyList.of($first, ..$rest).reduceK)" } - val endpoints = (routes: List[Route]) => - routes.flatMap { route => - val name = Term.Name(route.name.last) - val endpointsName = Term.Select(Term.Name("endpoints"), name) - val controllersName = Term.Select(Term.Name("controller"), name) - val controllerContent = - if (route.params.length <= 1) Some(controllersName) - else Some(Term.Select(Term.Eta(controllersName), Term.Name("tupled"))) - controllerContent.map { content => - val toRoutes = Term.Apply(Term.Select(endpointsName, Term.Name("toRoutes")), List(content)) - q"val ${Pat.Var(name)} = $toRoutes" - } + val endpoints = (routes: List[TapiroRoute]) => + routes.map { route => + val name = Term.Name(route.route.name.last) + val endpointImpl = Meta.toEndpointImplementation(route) + q"val ${Pat.Var(name)} = endpoints.$name.toRoutes($endpointImpl)" } } diff --git a/tapiro/core/src/main/scala/io/buildo/tapiro/Meta.scala b/tapiro/core/src/main/scala/io/buildo/tapiro/Meta.scala index eae28e37..d5ce5340 100644 --- a/tapiro/core/src/main/scala/io/buildo/tapiro/Meta.scala +++ b/tapiro/core/src/main/scala/io/buildo/tapiro/Meta.scala @@ -3,44 +3,49 @@ package io.buildo.tapiro import io.buildo.metarpheus.core.intermediate.{TaggedUnion, Type => MetarpheusType} import scala.meta._ +import scala.meta.contrib._ import cats.data.NonEmptyList object Meta { val codecsImplicits = (routes: List[TapiroRoute]) => { - val jsonCodecs = (routes.flatMap { - case TapiroRoute(route, error) => - val params: List[MetarpheusType] = route.params.map(_.tpe) - ((if (route.method == "post") params else Nil) ++ - (error match { - case TapiroRouteError.OtherError(t) => List(t) - case _ => Nil - }) :+ - route.returns) - }.distinct - .filter(t => typeNameString(t) != "Unit") //no json codec for Unit in tapir - .filter(t => typeNameString(t) != "String") - .filter(t => typeNameString(t) != "AuthToken") - .map(toScalametaType) - ++ taggedUnionErrorMembers(routes)) - .map(t => t"JsonCodec[$t]") - val plainCodecs = routes.flatMap { - case TapiroRoute(route, _) => - (if (route.method == "get") route.params.map(_.tpe) else Nil) ++ - route.params.map(_.tpe).filter(typeNameString(_) == "AuthToken") - }.distinct.map(t => t"PlainCodec[${toScalametaType(t)}]") - val codecs = jsonCodecs ++ plainCodecs - codecs.zipWithIndex.map(toImplicitParam.tupled) + val notUnit = (t: MetarpheusType) => t != MetarpheusType.Name("Unit") + val toDecoder = (t: Type) => t"Decoder[$t]" + val toEncoder = (t: Type) => t"Encoder[$t]" + val toJsonCodec = (t: Type) => t"JsonCodec[$t]" + val toPlainCodec = (t: Type) => t"PlainCodec[$t]" + val routeRequiredImplicits = (route: TapiroRoute) => { + val (authParamTypes, nonAuthParamTypes) = + route.route.params.map(_.tpe).partition(isAuthToken) + val inputImplicits = + route.method match { + case RouteMethod.GET => + nonAuthParamTypes.map(toScalametaType).map(toPlainCodec) + case RouteMethod.POST => + nonAuthParamTypes.map(toScalametaType).flatMap(t => List(toDecoder(t), toEncoder(t))) + } + val outputImplicits = + List(route.route.returns).filter(notUnit).map(toScalametaType).map(toJsonCodec) + val errorImplicits = + route.error match { + case RouteError.TaggedUnionError(tu) => + tu.values.map(taggedUnionMemberType(tu)).map(toJsonCodec) + case RouteError.OtherError(t) => + List(t).filter(notUnit).map(toScalametaType).map(toJsonCodec) + } + val authImplicits = authParamTypes.map(toScalametaType).map(toPlainCodec) + inputImplicits ++ outputImplicits ++ errorImplicits ++ authImplicits + } + deduplicate(routes.flatMap(routeRequiredImplicits)).zipWithIndex.map(toImplicitParam.tupled) } - private[this] val taggedUnionErrorMembers = (routes: List[TapiroRoute]) => { - val taggedUnions = routes.collect { - case TapiroRoute(_, TapiroRouteError.TaggedUnionError(tu)) => tu - }.distinct - taggedUnions.flatMap { taggedUnion => - taggedUnion.values.map(taggedUnionMemberType(taggedUnion)) + private[this] val deduplicate: List[Type] => List[Type] = (ts: List[Type]) => + ts match { + case Nil => Nil + case head :: tail => head :: deduplicate(tail.filter(!_.isEqual(head))) } - } + + private[this] val isAuthToken = (t: MetarpheusType) => t == MetarpheusType.Name("AuthToken") private[this] val toImplicitParam = (paramType: Type, index: Int) => { val paramName = Term.Name(s"codec$index") @@ -71,4 +76,27 @@ object Meta { def packageFromList(`package`: NonEmptyList[String]): Term.Ref = `package`.tail .foldLeft[Term.Ref](Term.Name(`package`.head))((acc, n) => Term.Select(acc, Term.Name(n))) + + val toEndpointImplementation = (route: TapiroRoute) => { + val name = Term.Name(route.route.name.last) + val controllersName = q"controller.$name" + route.method match { + case RouteMethod.GET => + route.route.params.length match { + case 0 => q"_ => $controllersName()" + case 1 => controllersName + case _ => q"($controllersName _).tupled" + } + case RouteMethod.POST => + val fields = route.route.params + .filterNot(_.tpe == MetarpheusType.Name("AuthToken")) + .map(p => Term.Name(p.name.getOrElse(Meta.typeNameString(p.tpe)))) + val hasAuth = route.route.params + .exists(_.tpe == MetarpheusType.Name("AuthToken")) + if (hasAuth) + q"{ case (x, token) => $controllersName(..${fields.map(f => q"x.$f")}, token) }" + else + q"x => $controllersName(..${fields.map(f => q"x.$f")})" + } + } } diff --git a/tapiro/core/src/main/scala/io/buildo/tapiro/MetarpheusHelper.scala b/tapiro/core/src/main/scala/io/buildo/tapiro/MetarpheusHelper.scala index 8179f4fc..00141e2e 100644 --- a/tapiro/core/src/main/scala/io/buildo/tapiro/MetarpheusHelper.scala +++ b/tapiro/core/src/main/scala/io/buildo/tapiro/MetarpheusHelper.scala @@ -3,7 +3,18 @@ package io.buildo.tapiro import io.buildo.metarpheus.core.intermediate.{Type => MetarpheusType, Model, TaggedUnion, Route} object MetarpheusHelper { - def routeError(route: Route, models: List[Model]): TapiroRouteError = + def toTapiroRoute(models: List[Model])(route: Route): TapiroRoute = + TapiroRoute( + route = route, + method = route.method match { + case "get" => RouteMethod.GET + case "post" => RouteMethod.POST + case _ => throw new Exception("method not supported") + }, + error = routeError(route, models), + ) + + def routeError(route: Route, models: List[Model]): RouteError = route.error.map { error => val errorName = error match { case MetarpheusType.Name(name) => name @@ -16,7 +27,7 @@ object MetarpheusHelper { if (candidates.length > 1) throw new Exception(s"ambiguous error type name $errorName") else candidates.headOption - .map(TapiroRouteError.TaggedUnionError.apply) - .getOrElse(TapiroRouteError.OtherError(error)) - }.getOrElse(TapiroRouteError.OtherError(MetarpheusType.Name("String"))) + .map(RouteError.TaggedUnionError.apply) + .getOrElse(RouteError.OtherError(error)) + }.getOrElse(RouteError.OtherError(MetarpheusType.Name("String"))) } diff --git a/tapiro/core/src/main/scala/io/buildo/tapiro/TapirMeta.scala b/tapiro/core/src/main/scala/io/buildo/tapiro/TapirMeta.scala index deb04517..684e3190 100644 --- a/tapiro/core/src/main/scala/io/buildo/tapiro/TapirMeta.scala +++ b/tapiro/core/src/main/scala/io/buildo/tapiro/TapirMeta.scala @@ -1,6 +1,11 @@ package io.buildo.tapiro -import io.buildo.metarpheus.core.intermediate.{RouteParam, TaggedUnion, Type => MetarpheusType} +import io.buildo.metarpheus.core.intermediate.{ + Route, + RouteParam, + TaggedUnion, + Type => MetarpheusType, +} import scala.meta._ @@ -15,11 +20,16 @@ object TapirMeta { tapirEndpointsName: Term.Name, implicits: List[Term.Param], body: List[Defn.Val], + postInputClassDeclarations: List[Defn.Class], + postInputCodecDeclarations: List[Defn.Val], ) => q""" package ${`package`} { ..${imports.toList.map(i => q"import $i._")} + import io.circe.{ Decoder, Encoder } + import io.circe.generic.semiauto.{ deriveDecoder, deriveEncoder } import sttp.tapir._ + import sttp.tapir.json.circe._ import sttp.tapir.Codec.{ JsonCodec, PlainCodec } import sttp.model.StatusCode @@ -32,10 +42,13 @@ object TapirMeta { Type.Name(tapirEndpointsName.value), Name.Anonymous(), Nil, - )}[AuthToken] { ..${body.map( - d => d.copy(mods = mod"override" :: d.mods), - )} } + )}[AuthToken] { + ..${postInputCodecDeclarations} + ..${body.map(d => d.copy(mods = mod"override" :: d.mods))} + } } + + ..${postInputClassDeclarations} } """ @@ -44,33 +57,45 @@ object TapirMeta { private[this] val endpointType = (route: TapiroRoute) => { val returnType = toScalametaType(route.route.returns) - val argsList = route.route.params.map(p => toScalametaType(p.tpe)) - val argsType = argsList match { - case Nil => Type.Name("Unit") - case head :: Nil => head - case l => Type.Tuple(l) + val argsType = route.method match { + case RouteMethod.GET => + val argsList = route.route.params.map(p => toScalametaType(p.tpe)) + argsList match { + case Nil => Type.Name("Unit") + case head :: Nil => head + case l => Type.Tuple(l) + } + case RouteMethod.POST => + val authTokenType = route.route.params + .filter(_.tpe == MetarpheusType.Name(authTokenName)) + .map(t => toScalametaType(t.tpe)) + .headOption + val inputType = postInputType(route.route) + authTokenType match { + case Some(t) => Type.Tuple(List(inputType, t)) + case None => inputType + } } val error = toScalametaType(route.error match { - case TapiroRouteError.TaggedUnionError(t) => MetarpheusType.Name(t.name) - case TapiroRouteError.OtherError(t) => t + case RouteError.TaggedUnionError(t) => MetarpheusType.Name(t.name) + case RouteError.OtherError(t) => t }) t"Endpoint[$argsType, $error, $returnType, Nothing]" } private[this] val endpointImpl = (route: TapiroRoute) => { + val method = route.method match { + case RouteMethod.GET => "get" + case RouteMethod.POST => "post" + } val basicEndpoint = Term.Apply( Term - .Select(Term.Select(Term.Name("endpoint"), Term.Name(route.route.method)), Term.Name("in")), + .Select(Term.Select(Term.Name("endpoint"), Term.Name(method)), Term.Name("in")), List(Lit.String(route.route.name.tail.mkString)), ) - val (auth, params) = route.route.params.partition(_.tpe == MetarpheusType.Name(authTokenName)) - val endpointsWithParams = withParams(basicEndpoint, route.route.method, params) withOutput( withError( - auth match { - case Nil => endpointsWithParams - case _ => withAuth(endpointsWithParams) - }, + withParams(basicEndpoint, route), route.error, ), route.route.returns, @@ -91,45 +116,47 @@ object TapirMeta { ), ) - private[this] val withParams = - (endpoint: meta.Term, method: String, params: List[RouteParam]) => { - method match { - case "get" => - params.foldLeft(endpoint) { (acc, param) => - withParam(acc, param) - } - case "post" => - params.foldLeft(endpoint) { (acc, param) => - withBody(acc, param.tpe) - } - case _ => throw new Exception("method not supported") - }, + private[this] val withParams = (endpoint: meta.Term, route: TapiroRoute) => { + val (auth, params) = route.route.params.partition(_.tpe == MetarpheusType.Name(authTokenName)) + val endpointWithParams = route.method match { + case RouteMethod.GET => + params.foldLeft(endpoint) { (acc, param) => + withParam(acc, param) + } + case RouteMethod.POST => + withBody(endpoint, route.route) + } + auth match { + case Nil => endpointWithParams + case _ => withAuth(endpointWithParams) } + } - private[this] val withBody = (endpoint: meta.Term, tpe: MetarpheusType) => { + private[this] val withBody = (endpoint: meta.Term, route: Route) => { Term.Apply( Term.Select(endpoint, Term.Name("in")), - List(Term.ApplyType(Term.Name("jsonBody"), List(toScalametaType(tpe)))), + List(Term.ApplyType(Term.Name("jsonBody"), List(postInputType(route)))), ), } private[this] val withError = - (endpoints: meta.Term, routeError: TapiroRouteError) => + (endpoints: meta.Term, routeError: RouteError) => routeError match { - case TapiroRouteError.OtherError(t) if typeNameString(t) == "Unit" => endpoints - case _ => Term.Apply( - Term.Select(endpoints, Term.Name("errorOut")), - List( - routeError match { - case TapiroRouteError.TaggedUnionError(taggedUnion) => - listErrors(taggedUnion) - case TapiroRouteError.OtherError(MetarpheusType.Name("String")) => - Term.Name("stringBody") - case TapiroRouteError.OtherError(t) => - Term.ApplyType(Term.Name("jsonBody"), List(toScalametaType(t))) - }, - ), - ) + case RouteError.OtherError(t) if typeNameString(t) == "Unit" => endpoints + case _ => + Term.Apply( + Term.Select(endpoints, Term.Name("errorOut")), + List( + routeError match { + case RouteError.TaggedUnionError(taggedUnion) => + listErrors(taggedUnion) + case RouteError.OtherError(MetarpheusType.Name("String")) => + Term.Name("stringBody") + case RouteError.OtherError(t) => + Term.ApplyType(Term.Name("jsonBody"), List(toScalametaType(t))) + }, + ), + ) } private[this] val listErrors = (taggedUnion: TaggedUnion) => @@ -150,11 +177,6 @@ object TapirMeta { typeNameString(returnType) match { case "Unit" => endpoint - case "String" => - Term.Apply( - Term.Select(endpoint, Term.Name("out")), - List(Term.Name("stringBody")), - ) case _ => Term.Apply( Term.Select(endpoint, Term.Name("out")), @@ -185,4 +207,35 @@ object TapirMeta { Term.Apply(Term.Select(noDesc, Term.Name("description")), List(Lit.String(desc))) } } + + private[this] val postInputType = (route: Route) => + Type.Name(route.name.tail.mkString.capitalize + "RequestPayload") + + val routeClassDeclarations = (route: TapiroRoute) => + route.method match { + case RouteMethod.POST => + val params = route.route.params + .filterNot(_.tpe == MetarpheusType.Name(authTokenName)) + .map { p => + param"${Term.Name(p.name.getOrElse(typeNameString(p.tpe)))}: ${toScalametaType(p.tpe)}" + } + List(q"case class ${postInputType(route.route)}(..$params)") + case RouteMethod.GET => + Nil + } + + val routeCodecDeclarations = (route: TapiroRoute) => { + val mkDeclaration = (s: String) => { + val name = Pat.Var(Term.Name(route.route.name.tail.mkString + "RequestPayload" + s)) + val tpe = postInputType(route.route) + val fun = Term.Name("derive" + s) + q"implicit val $name : ${Type.Name(s)}[$tpe] = $fun" + } + route.method match { + case RouteMethod.POST => + List("Decoder", "Encoder").map(mkDeclaration) + case RouteMethod.GET => + Nil + } + } } diff --git a/tapiro/core/src/main/scala/io/buildo/tapiro/Util.scala b/tapiro/core/src/main/scala/io/buildo/tapiro/Util.scala index aceb3fd5..57e4ad7e 100644 --- a/tapiro/core/src/main/scala/io/buildo/tapiro/Util.scala +++ b/tapiro/core/src/main/scala/io/buildo/tapiro/Util.scala @@ -25,13 +25,19 @@ object Server { case object NoServer extends Server } -sealed trait TapiroRouteError -object TapiroRouteError { - case class TaggedUnionError(taggedUnion: TaggedUnion) extends TapiroRouteError - case class OtherError(`type`: MetarpheusType) extends TapiroRouteError +sealed trait RouteError +object RouteError { + case class TaggedUnionError(taggedUnion: TaggedUnion) extends RouteError + case class OtherError(`type`: MetarpheusType) extends RouteError } -case class TapiroRoute(route: Route, error: TapiroRouteError) +sealed trait RouteMethod +object RouteMethod { + case object GET extends RouteMethod + case object POST extends RouteMethod +} + +case class TapiroRoute(route: Route, method: RouteMethod, error: RouteError) class Util() { import Formatter.format @@ -49,9 +55,8 @@ class Util() { case Some(nonEmptyPackage) => val config = Config(Set.empty) val models = Metarpheus.run(modelsPaths, config).models - val routes: List[TapiroRoute] = Metarpheus.run(routesPaths, config).routes.map { route => - TapiroRoute(route, routeError(route, models)) - } + val routes: List[TapiroRoute] = + Metarpheus.run(routesPaths, config).routes.map(toTapiroRoute(models)) val controllersRoutes = routes.groupBy( route => (route.route.controllerType, route.route.pathName), @@ -125,6 +130,8 @@ class Util() { Term.Name(tapirEndpointsName), Meta.codecsImplicits(routes), routes.map(TapirMeta.routeToTapirEndpoint), + routes.flatMap(TapirMeta.routeClassDeclarations), + routes.flatMap(TapirMeta.routeCodecDeclarations), ), ) } @@ -151,7 +158,7 @@ class Util() { Term.Name(tapirEndpointsName), Term.Name(httpEndpointsName), Meta.codecsImplicits(tapiroRoutes) :+ param"implicit cs: ContextShift[F]", - Http4sMeta.endpoints(routes), + Http4sMeta.endpoints(tapiroRoutes), Http4sMeta.routes(Lit.String(pathName), head, tail), ), ), @@ -181,7 +188,7 @@ class Util() { Term.Name(tapirEndpointsName), Term.Name(httpEndpointsName), Meta.codecsImplicits(tapiroRoutes), - AkkaHttpMeta.endpoints(routes), + AkkaHttpMeta.endpoints(tapiroRoutes), AkkaHttpMeta.routes(Lit.String(pathName), head, tail), ), ), diff --git a/tapiro/core/src/test/scala/io/buildo/tapiro/TapiroSuite.scala b/tapiro/core/src/test/scala/io/buildo/tapiro/TapiroSuite.scala index 796a7b51..e61dceb7 100644 --- a/tapiro/core/src/test/scala/io/buildo/tapiro/TapiroSuite.scala +++ b/tapiro/core/src/test/scala/io/buildo/tapiro/TapiroSuite.scala @@ -5,7 +5,7 @@ import java.nio.file.Files class TapiroSuite extends munit.FunSuite { check( - "http4s", + "tapir-http4s-endpoints", Server.Http4s, "src/main/scala/schools/endpoints", """ @@ -14,14 +14,23 @@ class TapiroSuite extends munit.FunSuite { | |case class School(id: Long, name: String) | + |sealed trait SchoolCreateError + |object SchoolCreateError { + | case object DuplicateId extends SchoolCreateError + |} + | |sealed trait SchoolReadError |object SchoolReadError { | case object NotFound extends SchoolReadError |} | - |trait SchoolController[F[_], T] { + |trait SchoolController[F[_], AuthToken] { + | @command + | def create(school: School, token: AuthToken): F[Either[SchoolCreateError, Unit]] | @query | def read(id: Long): F[Either[SchoolReadError, School]] + | @query + | def list(): F[Either[Unit, List[School]]] |} |""".stripMargin, """ @@ -34,21 +43,58 @@ class TapiroSuite extends munit.FunSuite { | |package endpoints |import schools._ + |import io.circe.{Decoder, Encoder} + |import io.circe.generic.semiauto.{deriveDecoder, deriveEncoder} |import sttp.tapir._ + |import sttp.tapir.json.circe._ |import sttp.tapir.Codec.{JsonCodec, PlainCodec} |import sttp.model.StatusCode | |trait SchoolControllerTapirEndpoints[AuthToken] { + | + | val create: Endpoint[ + | (CreateRequestPayload, AuthToken), + | SchoolCreateError, + | Unit, + | Nothing + | ] | val read: Endpoint[Long, SchoolReadError, School, Nothing] + | val list: Endpoint[Unit, Unit, List[School], Nothing] |} | |object SchoolControllerTapirEndpoints { | | def create[AuthToken](statusCodes: String => StatusCode)( - | implicit codec0: JsonCodec[School], - | codec1: JsonCodec[SchoolReadError.NotFound.type], - | codec2: PlainCodec[Long] + | implicit codec0: Decoder[School], + | codec1: Encoder[School], + | codec2: JsonCodec[SchoolCreateError.DuplicateId.type], + | codec3: PlainCodec[AuthToken], + | codec4: PlainCodec[Long], + | codec5: JsonCodec[School], + | codec6: JsonCodec[SchoolReadError.NotFound.type], + | codec7: JsonCodec[List[School]] | ) = new SchoolControllerTapirEndpoints[AuthToken] { + | implicit val createRequestPayloadDecoder: Decoder[CreateRequestPayload] = + | deriveDecoder + | implicit val createRequestPayloadEncoder: Encoder[CreateRequestPayload] = + | deriveEncoder + | override val create: Endpoint[ + | (CreateRequestPayload, AuthToken), + | SchoolCreateError, + | Unit, + | Nothing + | ] = endpoint.post + | .in("create") + | .in(jsonBody[CreateRequestPayload]) + | .in(header[AuthToken]("Authorization")) + | .errorOut( + | oneOf[SchoolCreateError]( + | statusMapping( + | statusCodes("DuplicateId"), + | jsonBody[SchoolCreateError.DuplicateId.type] + | ) + | ) + | ) | override val read: Endpoint[Long, SchoolReadError, School, Nothing] = | endpoint.get | .in("read") @@ -62,8 +108,11 @@ class TapiroSuite extends munit.FunSuite { | ) | ) | .out(jsonBody[School]) + | override val list: Endpoint[Unit, Unit, List[School], Nothing] = + | endpoint.get.in("list").out(jsonBody[List[School]]) | } |} + |case class CreateRequestPayload(school: School) | |/src/main/scala/schools/endpoints/SchoolControllerHttpEndpoints.scala |//---------------------------------------------------------- @@ -77,6 +126,7 @@ class TapiroSuite extends munit.FunSuite { |import cats.effect._ |import cats.implicits._ |import cats.data.NonEmptyList + |import io.circe.{Decoder, Encoder} |import org.http4s._ |import org.http4s.server.Router |import sttp.tapir.server.http4s._ @@ -89,15 +139,102 @@ class TapiroSuite extends munit.FunSuite { | controller: SchoolController[F, AuthToken], | statusCodes: String => StatusCode = _ => StatusCode.UnprocessableEntity | )( - | implicit codec0: JsonCodec[School], - | codec1: JsonCodec[SchoolReadError.NotFound.type], - | codec2: PlainCodec[Long], + | implicit codec0: Decoder[School], + | codec1: Encoder[School], + | codec2: JsonCodec[SchoolCreateError.DuplicateId.type], + | codec3: PlainCodec[AuthToken], + | codec4: PlainCodec[Long], + | codec5: JsonCodec[School], + | codec6: JsonCodec[SchoolReadError.NotFound.type], + | codec7: JsonCodec[List[School]], | cs: ContextShift[F] | ): HttpRoutes[F] = { | val endpoints = | SchoolControllerTapirEndpoints.create[AuthToken](statusCodes) + | val create = endpoints.create.toRoutes({ + | case (x, token) => + | controller.create(x.school, token) + | }) | val read = endpoints.read.toRoutes(controller.read) - | Router("/SchoolController" -> NonEmptyList(read, List()).reduceK) + | val list = endpoints.list.toRoutes(_ => controller.list()) + | Router("/SchoolController" -> NonEmptyList.of(create, read, list).reduceK) + | } + |} + |""".stripMargin, + ) + + check( + "akkaHttp-endpoints", + Server.AkkaHttp, + "src/main/scala/schools/endpoints", + """ + |/src/main/scala/schools/SchoolController.scala + |package schools + | + |case class School(id: Long, name: String) + | + |sealed trait SchoolCreateError + |object SchoolCreateError { + | case object DuplicateId extends SchoolCreateError + |} + | + |sealed trait SchoolReadError + |object SchoolReadError { + | case object NotFound extends SchoolReadError + |} + | + |trait SchoolController[AuthToken] { + | @command + | def create(school: School, token: AuthToken): Future[Either[SchoolCreateError, Unit]] + | @query + | def read(id: Long): Future[Either[SchoolReadError, School]] + | @query + | def list(): F[Either[Unit, List[School]]] + |} + |""".stripMargin, + """ + |/src/main/scala/schools/endpoints/SchoolControllerHttpEndpoints.scala + |//---------------------------------------------------------- + |// This code was generated by tapiro. + |// Changes to this file may cause incorrect behavior + |// and will be lost if the code is regenerated. + |//---------------------------------------------------------- + | + |package endpoints + |import schools._ + |import akka.http.scaladsl.server._ + |import akka.http.scaladsl.server.Directives._ + |import io.circe.{Decoder, Encoder} + |import sttp.tapir.server.akkahttp._ + |import sttp.tapir.Codec.{JsonCodec, PlainCodec} + |import sttp.model.StatusCode + | + |object SchoolControllerHttpEndpoints { + | + | def routes[AuthToken]( + | controller: SchoolController[AuthToken], + | statusCodes: String => StatusCode = _ => StatusCode.UnprocessableEntity + | )( + | implicit codec0: Decoder[School], + | codec1: Encoder[School], + | codec2: JsonCodec[SchoolCreateError.DuplicateId.type], + | codec3: PlainCodec[AuthToken], + | codec4: PlainCodec[Long], + | codec5: JsonCodec[School], + | codec6: JsonCodec[SchoolReadError.NotFound.type], + | codec7: JsonCodec[List[School]] + | ): Route = { + | val endpoints = + | SchoolControllerTapirEndpoints.create[AuthToken](statusCodes) + | val create = endpoints.create.toRoute({ + | case (x, token) => + | controller.create(x.school, token) + | }) + | val read = endpoints.read.toRoute(controller.read) + | val list = endpoints.list.toRoute(_ => controller.list()) + | pathPrefix("SchoolController") { + | List(read, list).foldLeft[Route](create)(_ ~ _) + | } | } |} |""".stripMargin, diff --git a/tapiro/sbt-tapiro/src/sbt-test/sbt-tapiro/simple/test b/tapiro/sbt-tapiro/src/sbt-test/sbt-tapiro/simple/test index 19af1bc7..afd16f56 100644 --- a/tapiro/sbt-tapiro/src/sbt-test/sbt-tapiro/simple/test +++ b/tapiro/sbt-tapiro/src/sbt-test/sbt-tapiro/simple/test @@ -10,4 +10,4 @@ $ exists "src/main/scala/endpoints/ExampleControllerHttpEndpoints.scala" # check that the endpoints respond as expected # NOTE(gabro): the single quotes surrounding the commands are a workaround for https://github.com/sbt/sbt/issues/4870 > 'curlExpect -s -X GET localhost:8080/ExampleController/queryExample?intParam=1&stringParam=abc {"name":"abc","double":1.0}' -> 'curlExpect -s -X POST localhost:8080/ExampleController/commandExample -d "{\"name\":\"abc\",\"double\":1.0}" abc' +> 'curlExpect -s -X POST localhost:8080/ExampleController/commandExample -d "{\"body\": {\"name\":\"abc\",\"double\":1.0}}" "\"abc\""'