diff --git a/build.sbt b/build.sbt index ee019136..6f74ae06 100644 --- a/build.sbt +++ b/build.sbt @@ -4,13 +4,13 @@ import java.nio.file.StandardCopyOption.REPLACE_EXISTING ThisBuild / scalaVersion := "3.3.0" ThisBuild / organization := "net.wiringbits" -val playJson = "2.10.0-RC9" +val playJson = "3.0.1" val sttp = "3.8.15" val webappUtils = "0.7.2" val anorm = "2.7.0" val enumeratum = "1.7.2" val scalaJavaTime = "2.5.0" -val tapir = "1.5.0" +val tapir = "1.8.5" val chimney = "0.8.0-RC1" val consoleDisabledOptions = Seq("-Werror", "-Ywarn-unused", "-Ywarn-unused-import") @@ -226,7 +226,7 @@ lazy val common = (crossProject(JSPlatform, JVMPlatform) in file("lib/common")) ) .jvmSettings( libraryDependencies ++= Seq( - "com.typesafe.play" %% "play-json" % playJson, + "org.playframework" %% "play-json" % playJson, "net.wiringbits" %% "webapp-common" % webappUtils, "org.scalatest" %% "scalatest" % "3.2.16" % Test ) @@ -238,7 +238,7 @@ lazy val common = (crossProject(JSPlatform, JVMPlatform) in file("lib/common")) Compile / stMinimize := Selection.All, libraryDependencies ++= Seq( "io.github.cquiroz" %%% "scala-java-time" % scalaJavaTime, - "com.typesafe.play" %%% "play-json" % playJson, + "org.playframework" %%% "play-json" % playJson, "net.wiringbits" %%% "webapp-common" % webappUtils, "org.scalatest" %%% "scalatest" % "3.2.16" % Test, "com.beachape" %%% "enumeratum" % enumeratum @@ -247,14 +247,14 @@ lazy val common = (crossProject(JSPlatform, JVMPlatform) in file("lib/common")) // shared apis lazy val api = (crossProject(JSPlatform, JVMPlatform) in file("lib/api")) - .dependsOn(common, tapirPlayJson) + .dependsOn(common) .configure(baseLibSettings, commonSettings) .jsConfigure(_.enablePlugins(ScalaJSPlugin, ScalaJSBundlerPlugin, ScalablyTypedConverterPlugin)) .jvmSettings( libraryDependencies ++= Seq( - "com.typesafe.play" %% "play-json" % playJson, + "org.playframework" %% "play-json" % playJson, "com.softwaremill.sttp.client3" %% "core" % sttp, - "com.softwaremill.sttp.tapir" %% "tapir-core" % tapir, + "com.softwaremill.sttp.tapir" %% "tapir-json-play" % tapir, "com.softwaremill.sttp.tapir" %% "tapir-sttp-client" % tapir ) ) @@ -264,11 +264,11 @@ lazy val api = (crossProject(JSPlatform, JVMPlatform) in file("lib/api")) stUseScalaJsDom := true, Compile / stMinimize := Selection.All, libraryDependencies ++= Seq( - "com.typesafe.play" %%% "play-json" % playJson, - "com.softwaremill.sttp.client3" %%% "core" % sttp, + "org.playframework" %%% "play-json" % playJson, "org.scalatest" %%% "scalatest" % "3.2.16" % Test, "com.beachape" %%% "enumeratum" % enumeratum, - "com.softwaremill.sttp.tapir" %%% "tapir-core" % tapir, + "com.softwaremill.sttp.client3" %%% "core" % sttp, + "com.softwaremill.sttp.tapir" %%% "tapir-json-play" % tapir, "com.softwaremill.sttp.tapir" %%% "tapir-sttp-client" % tapir ) ) @@ -317,37 +317,8 @@ lazy val ui = (project in file("lib/ui")) ) ) -lazy val tapirServerCore = (project in file("tapir/core")) - .settings( - name := "tapir-server-core", - libraryDependencies += "com.softwaremill.sttp.tapir" %% "tapir-core" % tapir - ) - -lazy val tapirServerPlay = (project in file("tapir/tapir-play")) - .settings( - name := "tapir-server-play", - libraryDependencies ++= Seq( - "com.typesafe.play" %% "play-akka-http-server" % "2.9.0-M6", - "com.softwaremill.sttp.shared" %% "akka" % "1.3.14", - "org.scala-lang.modules" %% "scala-collection-compat" % "2.11.0" - ) - ) - .dependsOn(tapirServerCore) - -lazy val tapirPlayJson = (crossProject(JSPlatform, JVMPlatform) in file("tapir/playjson")) - .settings( - name := "tapir-play-json", - libraryDependencies ++= Seq( - "com.typesafe.play" %%% "play-json" % playJson, - "com.softwaremill.sttp.tapir" %% "tapir-core" % tapir - ) - ) - .jsSettings( - libraryDependencies += "io.github.cquiroz" %%% "scala-java-time" % scalaJavaTime - ) - lazy val server = (project in file("server")) - .dependsOn(common.jvm, api.jvm, tapirServerPlay, tapirPlayJson.jvm) + .dependsOn(common.jvm, api.jvm) .configure(baseServerSettings, commonSettings, playSettings) .settings( name := "wiringbits-server", @@ -355,9 +326,8 @@ lazy val server = (project in file("server")) Test / fork := true, // allows for graceful shutdown of containers once the tests have finished running libraryDependencies ++= Seq( "org.playframework.anorm" %% "anorm" % anorm, - "org.playframework.anorm" %% "anorm-akka" % anorm, "org.playframework.anorm" %% "anorm-postgres" % anorm, - "com.typesafe.play" %% "play-json" % playJson, + "org.playframework" %% "play-json" % playJson, "org.postgresql" % "postgresql" % "42.6.0", "de.svenkubiak" % "jBCrypt" % "0.4.3", "commons-validator" % "commons-validator" % "1.7", @@ -375,8 +345,11 @@ lazy val server = (project in file("server")) "javax.el" % "javax.el-api" % "3.0.0", "org.glassfish" % "javax.el" % "3.0.0", "com.beachape" %% "enumeratum" % enumeratum, + "io.scalaland" %% "chimney" % chimney, "com.softwaremill.sttp.tapir" %% "tapir-swagger-ui-bundle" % tapir, - "io.scalaland" %% "chimney" % chimney + "com.softwaremill.sttp.tapir" %% "tapir-json-play" % tapir, + "com.softwaremill.sttp.tapir" %% "tapir-play-server" % tapir, + "org.apache.pekko" %% "pekko-stream" % "1.0.1" ) ) @@ -431,7 +404,7 @@ lazy val web = (project in file("web")) "@types/react-google-recaptcha" -> "2.1.0" ), libraryDependencies ++= Seq( - "com.typesafe.play" %%% "play-json" % playJson, + "org.playframework" %%% "play-json" % playJson, "com.softwaremill.sttp.client3" %%% "core" % sttp, "org.scala-js" %%% "scala-js-macrotask-executor" % "1.1.1", "com.olvind.st-material-ui" %%% "st-material-ui-icons-slinky" % "5.11.16", diff --git a/project/plugins.sbt b/project/plugins.sbt index 0ff63251..7bee1182 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -3,7 +3,7 @@ evictionErrorLevel := sbt.util.Level.Warn addSbtPlugin("org.portable-scala" % "sbt-scalajs-crossproject" % "1.3.1") -addSbtPlugin("com.typesafe.play" % "sbt-plugin" % "2.9.0-M6") +addSbtPlugin("org.playframework" % "sbt-plugin" % "3.0.0") addSbtPlugin("org.scala-js" % "sbt-scalajs" % "1.13.1") diff --git a/server/src/main/scala/PekkoStream.scala b/server/src/main/scala/PekkoStream.scala new file mode 100644 index 00000000..d7746573 --- /dev/null +++ b/server/src/main/scala/PekkoStream.scala @@ -0,0 +1,201 @@ +/* + * Copyright (C) from 2022 The Play Framework Contributors , 2011-2021 Lightbend Inc. + */ + +package anorm + +import java.sql.Connection +import scala.util.control.NonFatal +import scala.concurrent.{Future, Promise} +import org.apache.pekko.stream.scaladsl.Source + +import scala.annotation.nowarn + +/** Anorm companion for the Pekko Streams. + * + * @define materialization + * It materializes a [[scala.concurrent.Future]] of [[scala.Int]] containing the number of rows read from the source + * upon completion, and a possible exception if row parsing failed. + * @define sqlParam + * the SQL query + * @define connectionParam + * the JDBC connection, which must not be closed until the source is materialized. + * @define columnAliaserParam + * the column aliaser + */ +// From https://github.com/playframework/anorm/blob/main/pekko/src/main/scala/anorm/PekkoStream.scala +// We are copying this because the anorm.pekko isn't published yet +// TODO: remove after anorm.pekko is published +object PekkoStream { + + /** Returns the rows parsed from the `sql` query as a reactive source. + * + * $materialization + * + * @tparam T + * the type of the result elements + * @param sql + * $sqlParam + * @param parser + * the result (row) parser + * @param as + * $columnAliaserParam + * @param connection + * $connectionParam + * + * {{{ + * import java.sql.Connection + * + * import scala.concurrent.Future + * + * import org.apache.pekko.stream.scaladsl.Source + * + * import anorm._ + * + * def resultSource(implicit con: Connection): Source[String, Future[Int]] = PekkoStream.source(SQL"SELECT * FROM Test", SqlParser.scalar[String], ColumnAliaser.empty) + * }}} + */ + @SuppressWarnings(Array("UnusedMethodParameter")) + def source[T](sql: => Sql, parser: RowParser[T], as: ColumnAliaser)(implicit + con: Connection + ): Source[T, Future[Int]] = Source.fromGraph(new ResultSource[T](con, sql, as, parser)) + + /** Returns the rows parsed from the `sql` query as a reactive source. + * + * $materialization + * + * @tparam T + * the type of the result elements + * @param sql + * $sqlParam + * @param parser + * the result (row) parser + * @param connection + * $connectionParam + */ + @SuppressWarnings(Array("UnusedMethodParameter")) + def source[T](sql: => Sql, parser: RowParser[T])(implicit con: Connection): Source[T, Future[Int]] = + source[T](sql, parser, ColumnAliaser.empty) + + /** Returns the result rows from the `sql` query as an enumerator. This is equivalent to `source[Row](sql, + * RowParser.successful, as)`. + * + * $materialization + * + * @param sql + * $sqlParam + * @param as + * $columnAliaserParam + * @param connection + * $connectionParam + */ + def source(sql: => Sql, as: ColumnAliaser)(implicit connection: Connection): Source[Row, Future[Int]] = + source(sql, RowParser.successful, as) + + /** Returns the result rows from the `sql` query as an enumerator. This is equivalent to `source[Row](sql, + * RowParser.successful, ColumnAliaser.empty)`. + * + * $materialization + * + * @param sql + * $sqlParam + * @param connection + * $connectionParam + */ + def source(sql: => Sql)(implicit connnection: Connection): Source[Row, Future[Int]] = + source(sql, RowParser.successful, ColumnAliaser.empty) + + // Internal stages + + import org.apache.pekko.stream.stage.{GraphStageLogic, GraphStageWithMaterializedValue, OutHandler} + import org.apache.pekko.stream.{Attributes, Outlet, SourceShape} + + import java.sql.ResultSet + import scala.util.{Failure, Success} + + private[anorm] class ResultSource[T](connection: Connection, sql: Sql, as: ColumnAliaser, parser: RowParser[T]) + extends GraphStageWithMaterializedValue[SourceShape[T], Future[Int]] { + + @SuppressWarnings(Array("org.wartremover.warts.Null")) + private[anorm] var resultSet: ResultSet = _ + + override val toString = "AnormQueryResult" + val out: Outlet[T] = Outlet(s"${toString}.out") + val shape: SourceShape[T] = SourceShape(out) + + override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, Future[Int]) = { + val result = Promise[Int]() + + val logic = new GraphStageLogic(shape) with OutHandler { + private var cursor: Option[Cursor] = None + private var counter: Int = 0 + + private def failWith(cause: Throwable): Unit = { + result.failure(cause) + fail(out, cause) + () + } + + override def preStart(): Unit = { + try { + resultSet = sql.unsafeResultSet(connection) + nextCursor() + } catch { + case NonFatal(cause) => failWith(cause) + } + } + + override def postStop() = release() + + private def release(): Unit = { + val stmt: Option[java.sql.Statement] = { + if (resultSet != null && !resultSet.isClosed) { + val s = resultSet.getStatement + resultSet.close() + Option(s) + } else None + } + + stmt.foreach { s => + if (!s.isClosed) s.close() + } + } + + private def nextCursor(): Unit = { + cursor = Sql.unsafeCursor(resultSet, sql.resultSetOnFirstRow, as) + } + + def onPull(): Unit = cursor match { + case Some(c) => + c.row.as(parser) match { + case Success(parsed) => { + counter += 1 + push(out, parsed) + nextCursor() + } + + case Failure(cause) => + failWith(cause) + } + + case _ => { + result.success(counter) + complete(out) + } + } + + @nowarn + override def onDownstreamFinish() = { + result.tryFailure(new InterruptedException("Downstream finished")) + release() + super.onDownstreamFinish() + } + + setHandler(out, this) + } + + logic -> result.future + } + } + +} diff --git a/server/src/main/scala/controllers/AdminController.scala b/server/src/main/scala/controllers/AdminController.scala index f2e39d8e..2577bae4 100644 --- a/server/src/main/scala/controllers/AdminController.scala +++ b/server/src/main/scala/controllers/AdminController.scala @@ -7,7 +7,7 @@ import net.wiringbits.common.models.Email import net.wiringbits.services.AdminService import org.slf4j.LoggerFactory import sttp.capabilities.WebSockets -import sttp.capabilities.akka.AkkaStreams +import sttp.capabilities.pekko.PekkoStreams import sttp.tapir.server.ServerEndpoint import java.util.UUID @@ -42,7 +42,7 @@ class AdminController @Inject() ( } yield Right(maskedResponse) } - def routes: List[ServerEndpoint[AkkaStreams with WebSockets, Future]] = { + def routes: List[ServerEndpoint[PekkoStreams with WebSockets, Future]] = { List( AdminEndpoints.getUserLogsEndpoint.serverLogic(getUserLogs), AdminEndpoints.getUsersEndpoint.serverLogic(getUsers) diff --git a/server/src/main/scala/controllers/ApiRouter.scala b/server/src/main/scala/controllers/ApiRouter.scala index 6cd7a87b..9e10aae7 100644 --- a/server/src/main/scala/controllers/ApiRouter.scala +++ b/server/src/main/scala/controllers/ApiRouter.scala @@ -1,8 +1,9 @@ package controllers -import akka.stream.Materializer import net.wiringbits.api.endpoints.* import net.wiringbits.config.SwaggerConfig +import org.apache.pekko.actor.ActorSystem +import org.apache.pekko.stream.Materializer import play.api.routing.Router.Routes import play.api.routing.SimpleRouter import sttp.apispec.openapi.Info @@ -21,8 +22,10 @@ class ApiRouter @Inject() ( usersController: UsersController, environmentConfigController: EnvironmentConfigController, swaggerConfig: SwaggerConfig -)(implicit materializer: Materializer, ec: ExecutionContext) +)(using ExecutionContext) extends SimpleRouter { + given ActorSystem = ActorSystem("ApiRouter") + private val swagger = SwaggerInterpreter( swaggerUIOptions = SwaggerUIOptions.default.copy(contextPath = List(swaggerConfig.basePath)) ) diff --git a/server/src/main/scala/controllers/AuthController.scala b/server/src/main/scala/controllers/AuthController.scala index 723ae111..d2d12521 100644 --- a/server/src/main/scala/controllers/AuthController.scala +++ b/server/src/main/scala/controllers/AuthController.scala @@ -6,7 +6,7 @@ import net.wiringbits.api.models.* import net.wiringbits.api.models.auth.{GetCurrentUser, Login, Logout} import org.slf4j.LoggerFactory import sttp.capabilities.WebSockets -import sttp.capabilities.akka.AkkaStreams +import sttp.capabilities.pekko.PekkoStreams import sttp.tapir.server.ServerEndpoint import java.util.UUID @@ -47,7 +47,7 @@ class AuthController @Inject() ( } yield Right(Logout.Response(), header) } - def routes: List[ServerEndpoint[AkkaStreams with WebSockets, Future]] = { + def routes: List[ServerEndpoint[PekkoStreams with WebSockets, Future]] = { List( AuthEndpoints.login.serverLogic(login), AuthEndpoints.getCurrentUser.serverLogic(me), diff --git a/server/src/main/scala/controllers/EnvironmentConfigController.scala b/server/src/main/scala/controllers/EnvironmentConfigController.scala index 25ffe4c1..a131cfd6 100644 --- a/server/src/main/scala/controllers/EnvironmentConfigController.scala +++ b/server/src/main/scala/controllers/EnvironmentConfigController.scala @@ -6,7 +6,7 @@ import net.wiringbits.api.models.ErrorResponse import net.wiringbits.api.models.environmentconfig.GetEnvironmentConfig import org.slf4j.LoggerFactory import sttp.capabilities.WebSockets -import sttp.capabilities.akka.AkkaStreams +import sttp.capabilities.pekko.PekkoStreams import sttp.tapir.server.ServerEndpoint import javax.inject.Inject @@ -24,7 +24,7 @@ class EnvironmentConfigController @Inject() ( } yield Right(response) } - def routes: List[ServerEndpoint[AkkaStreams with WebSockets, Future]] = { + def routes: List[ServerEndpoint[PekkoStreams with WebSockets, Future]] = { List(EnvironmentConfigEndpoints.getEnvironmentConfig.serverLogic(_ => getEnvironmentConfig)) } } diff --git a/server/src/main/scala/controllers/HealthController.scala b/server/src/main/scala/controllers/HealthController.scala index 46f3727e..2ba25546 100644 --- a/server/src/main/scala/controllers/HealthController.scala +++ b/server/src/main/scala/controllers/HealthController.scala @@ -2,7 +2,7 @@ package controllers import net.wiringbits.api.endpoints.HealthEndpoints import sttp.capabilities.WebSockets -import sttp.capabilities.akka.AkkaStreams +import sttp.capabilities.pekko.PekkoStreams import sttp.model.headers.{Cookie, CookieValueWithMeta, CookieWithMeta} import sttp.tapir.server.ServerEndpoint @@ -15,7 +15,7 @@ class HealthController @Inject() (implicit ec: ExecutionContext) { private def check: Future[Either[Unit, Unit]] = Future.successful(Right(())) - def routes: List[ServerEndpoint[AkkaStreams with WebSockets, Future]] = { + def routes: List[ServerEndpoint[PekkoStreams with WebSockets, Future]] = { List(HealthEndpoints.check.serverLogic(_ => check)) } } diff --git a/server/src/main/scala/controllers/UsersController.scala b/server/src/main/scala/controllers/UsersController.scala index 15098f5a..ebb5386a 100644 --- a/server/src/main/scala/controllers/UsersController.scala +++ b/server/src/main/scala/controllers/UsersController.scala @@ -7,7 +7,7 @@ import net.wiringbits.api.models.* import net.wiringbits.api.models.users.* import org.slf4j.LoggerFactory import sttp.capabilities.WebSockets -import sttp.capabilities.akka.AkkaStreams +import sttp.capabilities.pekko.PekkoStreams import sttp.tapir.server.ServerEndpoint import java.util.UUID @@ -100,7 +100,7 @@ class UsersController @Inject() ( } yield Right(response) } - def routes: List[ServerEndpoint[AkkaStreams with WebSockets, Future]] = { + def routes: List[ServerEndpoint[PekkoStreams with WebSockets, Future]] = { List( UsersEndpoints.create.serverLogic(create), UsersEndpoints.verifyEmail.serverLogic(verifyEmail), diff --git a/server/src/main/scala/net/wiringbits/actions/internal/StreamPendingBackgroundJobsForeverAction.scala b/server/src/main/scala/net/wiringbits/actions/internal/StreamPendingBackgroundJobsForeverAction.scala index 563b29c5..6926f0ba 100644 --- a/server/src/main/scala/net/wiringbits/actions/internal/StreamPendingBackgroundJobsForeverAction.scala +++ b/server/src/main/scala/net/wiringbits/actions/internal/StreamPendingBackgroundJobsForeverAction.scala @@ -1,7 +1,7 @@ package net.wiringbits.actions.internal -import akka.actor.ActorSystem -import akka.stream.scaladsl.* +import org.apache.pekko.actor.ActorSystem +import org.apache.pekko.stream.scaladsl.* import net.wiringbits.repositories.BackgroundJobsRepository import net.wiringbits.repositories.models.BackgroundJobData import org.slf4j.LoggerFactory @@ -16,13 +16,13 @@ class StreamPendingBackgroundJobsForeverAction @Inject() (backgroundJobsReposito ) { private val logger = LoggerFactory.getLogger(this.getClass) - def apply(reconnectionDelay: FiniteDuration = 10.seconds): Source[BackgroundJobData, akka.NotUsed] = { + def apply(reconnectionDelay: FiniteDuration = 10.seconds): Source[BackgroundJobData, org.apache.pekko.NotUsed] = { // Let's use unfoldAsync to continuously fetch items from database // First execution doesn't involve a delay Source .unfoldAsync[Boolean, Source[BackgroundJobData, Future[Int]]](false) { delay => logger.trace(s"Looking for pending background jobs") - akka.pattern + org.apache.pekko.pattern .after(if (delay) reconnectionDelay else 0.seconds) { backgroundJobsRepository.streamPendingJobs } diff --git a/server/src/main/scala/net/wiringbits/apis/EmailApiAWSImpl.scala b/server/src/main/scala/net/wiringbits/apis/EmailApiAWSImpl.scala index f8bb6a51..1293ba46 100644 --- a/server/src/main/scala/net/wiringbits/apis/EmailApiAWSImpl.scala +++ b/server/src/main/scala/net/wiringbits/apis/EmailApiAWSImpl.scala @@ -8,7 +8,7 @@ import software.amazon.awssdk.services.ses.SesAsyncClient import software.amazon.awssdk.services.ses.model.* import javax.inject.Inject -import scala.compat.java8.FutureConverters.CompletionStageOps +import scala.jdk.FutureConverters.CompletionStageOps import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.{Future, blocking} @@ -46,7 +46,7 @@ class EmailApiAWSImpl @Inject() ( for { response <- blocking { client.sendEmail(request) - }.toScala + }.asScala _ = logger.info( s"Email sent, to: ${emailRequest.destination}, subject = ${emailRequest.message.subject}, messageId = ${response.messageId()}" ) diff --git a/server/src/main/scala/net/wiringbits/executors/DatabaseExecutionContext.scala b/server/src/main/scala/net/wiringbits/executors/DatabaseExecutionContext.scala index d654c487..c19b70bd 100644 --- a/server/src/main/scala/net/wiringbits/executors/DatabaseExecutionContext.scala +++ b/server/src/main/scala/net/wiringbits/executors/DatabaseExecutionContext.scala @@ -1,6 +1,6 @@ package net.wiringbits.executors -import akka.actor.ActorSystem +import org.apache.pekko.actor.ActorSystem import play.api.libs.concurrent.CustomExecutionContext import javax.inject.{Inject, Singleton} diff --git a/server/src/main/scala/net/wiringbits/repositories/BackgroundJobsRepository.scala b/server/src/main/scala/net/wiringbits/repositories/BackgroundJobsRepository.scala index 1ce6462d..330f6f41 100644 --- a/server/src/main/scala/net/wiringbits/repositories/BackgroundJobsRepository.scala +++ b/server/src/main/scala/net/wiringbits/repositories/BackgroundJobsRepository.scala @@ -12,7 +12,7 @@ import scala.concurrent.Future import scala.util.control.NonFatal class BackgroundJobsRepository @Inject() (database: Database)(implicit ec: DatabaseExecutionContext, clock: Clock) { - def streamPendingJobs: Future[akka.stream.scaladsl.Source[BackgroundJobData, Future[Int]]] = Future { + def streamPendingJobs: Future[org.apache.pekko.stream.scaladsl.Source[BackgroundJobData, Future[Int]]] = Future { // autocommit=false is necessary to avoid loading the whole result into memory implicit val conn = database.getConnection(autocommit = false) try { diff --git a/server/src/main/scala/net/wiringbits/repositories/daos/BackgroundJobDAO.scala b/server/src/main/scala/net/wiringbits/repositories/daos/BackgroundJobDAO.scala index 0eb13827..c449c75b 100644 --- a/server/src/main/scala/net/wiringbits/repositories/daos/BackgroundJobDAO.scala +++ b/server/src/main/scala/net/wiringbits/repositories/daos/BackgroundJobDAO.scala @@ -34,7 +34,10 @@ object BackgroundJobDAO { def streamPendingJobs( allowedErrors: Int = 10, fetchSize: Int = 1000 - )(implicit conn: Connection, clock: Clock): akka.stream.scaladsl.Source[BackgroundJobData, Future[Int]] = { + )(implicit + conn: Connection, + clock: Clock + ): org.apache.pekko.stream.scaladsl.Source[BackgroundJobData, Future[Int]] = { val query = SQL""" SELECT background_job_id, type, payload, status, status_details, error_count, execute_at, created_at, updated_at FROM background_jobs @@ -44,10 +47,7 @@ object BackgroundJobDAO { ORDER BY execute_at, background_job_id """.withFetchSize(Some(fetchSize)) // without this, all data is loaded into memory - // this requires a Materializer that isn't used, better to set a null instead of depend on a Materializer - @SuppressWarnings(Array("org.wartremover.warts.Null")) - val materializer = null - AkkaStream.source(query, backgroundJobParser)(materializer, conn) + PekkoStream.source(query, backgroundJobParser)(conn) } def setStatusToFailed(backgroundJobId: UUID, executeAt: Instant, failReason: String)(implicit diff --git a/server/src/main/scala/net/wiringbits/tasks/BackgroundJobsExecutorTask.scala b/server/src/main/scala/net/wiringbits/tasks/BackgroundJobsExecutorTask.scala index 44570762..393814ec 100644 --- a/server/src/main/scala/net/wiringbits/tasks/BackgroundJobsExecutorTask.scala +++ b/server/src/main/scala/net/wiringbits/tasks/BackgroundJobsExecutorTask.scala @@ -1,6 +1,6 @@ package net.wiringbits.tasks -import akka.actor.ActorSystem +import org.apache.pekko.actor.ActorSystem import com.google.inject.Inject import net.wiringbits.actions.internal.StreamPendingBackgroundJobsForeverAction import net.wiringbits.apis.EmailApi @@ -73,7 +73,7 @@ class BackgroundJobsExecutorTask @Inject() ( // the reason to throttle and handle 1 background job concurrently is to avoid overloading the app val result = streamPendingBackgroundJobsForeverAction() .throttle(100, 1.minute) - .runWith(akka.stream.scaladsl.Sink.foreachAsync(1)(execute)) + .runWith(org.apache.pekko.stream.scaladsl.Sink.foreachAsync(1)(execute)) result.onComplete { case Failure(ex) => diff --git a/server/src/test/scala/net/wiringbits/repositories/BackgroundJobsRepositorySpec.scala b/server/src/test/scala/net/wiringbits/repositories/BackgroundJobsRepositorySpec.scala index 928e3e51..d6c300eb 100644 --- a/server/src/test/scala/net/wiringbits/repositories/BackgroundJobsRepositorySpec.scala +++ b/server/src/test/scala/net/wiringbits/repositories/BackgroundJobsRepositorySpec.scala @@ -1,7 +1,7 @@ package net.wiringbits.repositories -import akka.actor.ActorSystem -import akka.stream.scaladsl.* +import org.apache.pekko.actor.ActorSystem +import org.apache.pekko.stream.scaladsl.* import net.wiringbits.common.models.Email import net.wiringbits.core.RepositorySpec import net.wiringbits.models.jobs.{BackgroundJobPayload, BackgroundJobStatus, BackgroundJobType} diff --git a/server/src/test/scala/net/wiringbits/repositories/UsersRepositorySpec.scala b/server/src/test/scala/net/wiringbits/repositories/UsersRepositorySpec.scala index 7f7842bc..231e5986 100644 --- a/server/src/test/scala/net/wiringbits/repositories/UsersRepositorySpec.scala +++ b/server/src/test/scala/net/wiringbits/repositories/UsersRepositorySpec.scala @@ -1,7 +1,7 @@ package net.wiringbits.repositories -import akka.actor.ActorSystem -import akka.stream.scaladsl.Sink +import org.apache.pekko.actor.ActorSystem +import org.apache.pekko.stream.scaladsl.Sink import net.wiringbits.common.models.{Email, Name} import net.wiringbits.core.RepositorySpec import net.wiringbits.repositories.models.User diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/CustomiseInterceptors.scala b/tapir/core/src/main/scala/sttp/tapir/server/interceptor/CustomiseInterceptors.scala deleted file mode 100644 index 64dc8f56..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/CustomiseInterceptors.scala +++ /dev/null @@ -1,138 +0,0 @@ -package sttp.tapir.server.interceptor - -import sttp.tapir.server.interceptor.content.NotAcceptableInterceptor -import sttp.tapir.server.interceptor.cors.CORSInterceptor -import sttp.tapir.server.interceptor.decodefailure.{ - DecodeFailureHandler, - DecodeFailureInterceptor, - DefaultDecodeFailureHandler -} -import sttp.tapir.server.interceptor.exception.{DefaultExceptionHandler, ExceptionHandler, ExceptionInterceptor} -import sttp.tapir.server.interceptor.log.{ServerLog, ServerLogInterceptor} -import sttp.tapir.server.interceptor.metrics.MetricsRequestInterceptor -import sttp.tapir.server.interceptor.reject.{DefaultRejectHandler, RejectHandler, RejectInterceptor} -import sttp.tapir.server.model.ValuedEndpointOutput -import sttp.tapir.statusCode - -/** Allows customising the interceptors used by the server interpreter. Custom interceptors should usually be added - * using `addInterceptor`. That way, the custom interceptor is called after the built-in ones (such as logging, - * metrics, exceptions), and before the decode failure handler. For even more flexibility, interceptors can be added to - * the beginning or end of the interceptor stack, using `prependInterceptor` and `appendInterceptor`. - * - * The first interceptor in the interceptor stack is the one which is called first on request, and processes the - * resulting response as the last one. - * - * Built-in interceptors can be customised or disabled using the dedicated methods. - * - * Once done, use [[options]] to obtain the server interpreter options objects, which can be passed to the server - * interpreter. - * - * @param prependedInterceptors - * Additional interceptors, which will be called first on request / last on response, e.g. performing logging, - * metrics, or providing alternate responses. - * @param metricsInterceptor - * Whether to collect metrics. - * @param rejectHandler - * How to respond when decoding fails for all interpreted endpoints. - * @param exceptionHandler - * Whether to respond to exceptions in the server logic, or propagate them to the server. - * @param serverLog - * The server log using which an interceptor will be created, if any. - * @param notAcceptableInterceptor - * Whether to return 406 (not acceptable) if there's no body in the endpoint's outputs, which can satisfy the - * constraints from the `Accept` header. - * @param additionalInterceptors - * Additional interceptors, which will be called before (on request) / after (on response) the `decodeFailureHandler` - * one, e.g. performing logging, metrics, or providing alternate responses. - * @param decodeFailureHandler - * The decode failure handler, from which an interceptor will be created. Determines whether to respond when an input - * fails to decode. - * @param appendedInterceptors - * Additional interceptors, which will be called last on request / first on response, e.g. handling decode failures, - * or providing alternate responses. - */ -case class CustomiseInterceptors[F[_], O]( - createOptions: CustomiseInterceptors[F, O] => O, - prependedInterceptors: List[Interceptor[F]] = Nil, - metricsInterceptor: Option[MetricsRequestInterceptor[F]] = None, - corsInterceptor: Option[CORSInterceptor[F]] = None, - rejectHandler: Option[RejectHandler[F]] = Some(DefaultRejectHandler[F]), - exceptionHandler: Option[ExceptionHandler[F]] = Some(DefaultExceptionHandler[F]), - serverLog: Option[ServerLog[F]] = None, - notAcceptableInterceptor: Option[NotAcceptableInterceptor[F]] = Some(new NotAcceptableInterceptor[F]()), - additionalInterceptors: List[Interceptor[F]] = Nil, - decodeFailureHandler: DecodeFailureHandler = DefaultDecodeFailureHandler.default, - appendedInterceptors: List[Interceptor[F]] = Nil -) { - def prependInterceptor(i: Interceptor[F]): CustomiseInterceptors[F, O] = - copy(prependedInterceptors = prependedInterceptors :+ i) - - def metricsInterceptor(m: MetricsRequestInterceptor[F]): CustomiseInterceptors[F, O] = - copy(metricsInterceptor = Some(m)) - def metricsInterceptor(m: Option[MetricsRequestInterceptor[F]]): CustomiseInterceptors[F, O] = - copy(metricsInterceptor = m) - - def corsInterceptor(c: CORSInterceptor[F]): CustomiseInterceptors[F, O] = copy(corsInterceptor = Some(c)) - def corsInterceptor(c: Option[CORSInterceptor[F]]): CustomiseInterceptors[F, O] = copy(corsInterceptor = c) - - def rejectHandler(r: RejectHandler[F]): CustomiseInterceptors[F, O] = copy(rejectHandler = Some(r)) - def rejectHandler(r: Option[RejectHandler[F]]): CustomiseInterceptors[F, O] = copy(rejectHandler = r) - - def exceptionHandler(e: ExceptionHandler[F]): CustomiseInterceptors[F, O] = copy(exceptionHandler = Some(e)) - def exceptionHandler(e: Option[ExceptionHandler[F]]): CustomiseInterceptors[F, O] = copy(exceptionHandler = e) - - def serverLog(log: ServerLog[F]): CustomiseInterceptors[F, O] = copy(serverLog = Some(log)) - def serverLog(log: Option[ServerLog[F]]): CustomiseInterceptors[F, O] = copy(serverLog = log) - - def notAcceptableInterceptor(u: NotAcceptableInterceptor[F]): CustomiseInterceptors[F, O] = - copy(notAcceptableInterceptor = Some(u)) - def notAcceptableInterceptor(u: Option[NotAcceptableInterceptor[F]]): CustomiseInterceptors[F, O] = - copy(notAcceptableInterceptor = u) - - def addInterceptor(i: Interceptor[F]): CustomiseInterceptors[F, O] = - copy(additionalInterceptors = additionalInterceptors :+ i) - - def decodeFailureHandler(d: DecodeFailureHandler): CustomiseInterceptors[F, O] = copy(decodeFailureHandler = d) - - def appendInterceptor(i: Interceptor[F]): CustomiseInterceptors[F, O] = - copy(appendedInterceptors = appendedInterceptors :+ i) - - /** Use the default exception, decode failure and reject handlers. - * @param errorMessageOutput - * customise the way error messages are shown in error responses - * @param notFoundWhenRejected - * return a 404 formatted using `errorMessageOutput` when the request was rejected by all endpoints, instead of - * propagating the rejection to the server library - */ - def defaultHandlers( - errorMessageOutput: String => ValuedEndpointOutput[_], - notFoundWhenRejected: Boolean = false - ): CustomiseInterceptors[F, O] = { - copy( - exceptionHandler = Some(DefaultExceptionHandler((s, m) => errorMessageOutput(m).prepend(statusCode, s))), - decodeFailureHandler = DefaultDecodeFailureHandler.default.response(errorMessageOutput), - rejectHandler = Some( - DefaultRejectHandler( - (s, m) => errorMessageOutput(m).prepend(statusCode, s), - if (notFoundWhenRejected) Some(DefaultRejectHandler.Responses.NotFound) else None - ) - ) - ) - } - - // - - /** Creates the default interceptor stack */ - def interceptors: List[Interceptor[F]] = prependedInterceptors ++ - metricsInterceptor.toList ++ - corsInterceptor.toList ++ - rejectHandler.map(new RejectInterceptor[F](_)).toList ++ - exceptionHandler.map(new ExceptionInterceptor[F](_)).toList ++ - serverLog.map(new ServerLogInterceptor[F](_)).toList ++ - notAcceptableInterceptor.toList ++ - additionalInterceptors ++ - List(new DecodeFailureInterceptor[F](decodeFailureHandler)) ++ - appendedInterceptors - - def options: O = createOptions(this) -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/DecodeFailureContext.scala b/tapir/core/src/main/scala/sttp/tapir/server/interceptor/DecodeFailureContext.scala deleted file mode 100644 index 3d6030eb..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/DecodeFailureContext.scala +++ /dev/null @@ -1,11 +0,0 @@ -package sttp.tapir.server.interceptor - -import sttp.tapir.model.ServerRequest -import sttp.tapir.{AnyEndpoint, DecodeResult, EndpointInput} - -case class DecodeFailureContext( - endpoint: AnyEndpoint, - failingInput: EndpointInput[_], - failure: DecodeResult.Failure, - request: ServerRequest -) diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/DecodeSuccessContext.scala b/tapir/core/src/main/scala/sttp/tapir/server/interceptor/DecodeSuccessContext.scala deleted file mode 100644 index b21bcd49..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/DecodeSuccessContext.scala +++ /dev/null @@ -1,15 +0,0 @@ -package sttp.tapir.server.interceptor - -import sttp.tapir.Endpoint -import sttp.tapir.model.ServerRequest -import sttp.tapir.server.ServerEndpoint - -case class DecodeSuccessContext[F[_], A, U, I]( - serverEndpoint: ServerEndpoint.Full[A, U, I, _, _, _, F], - securityInput: A, - principal: U, - input: I, - request: ServerRequest -) { - def endpoint: Endpoint[A, I, _, _, _] = serverEndpoint.endpoint -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/EndpointHandler.scala b/tapir/core/src/main/scala/sttp/tapir/server/interceptor/EndpointHandler.scala deleted file mode 100644 index bfd60863..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/EndpointHandler.scala +++ /dev/null @@ -1,53 +0,0 @@ -package sttp.tapir.server.interceptor - -import sttp.monad.MonadError -import sttp.tapir.server.interpreter.BodyListener -import sttp.tapir.server.model.ServerResponse - -/** Handles the result of decoding a request using an endpoint's inputs. */ -trait EndpointHandler[F[_], B] { - - /** Called when the request has been successfully decoded into data, and when the security logic succeeded. This is captured by the `ctx` - * parameter. - * - * Called at most once per request. - * - * @tparam A - * The type of the endpoint's security inputs. - * @tparam U - * Type of the successful result of the security logic. - * @tparam I - * The type of the endpoint's inputs. - * @return - * An effect, describing the server's response. - */ - def onDecodeSuccess[A, U, I]( - ctx: DecodeSuccessContext[F, A, U, I] - )(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[ServerResponse[B]] - - /** Called when the security inputs have been successfully decoded into data, but the security logic failed (either with an error result - * or an exception). This is captured by the `ctx` parameter. - * - * Called at most once per request. - * - * @tparam A - * The type of the endpoint's security inputs. - * @return - * An effect, describing the server's response. - */ - def onSecurityFailure[A]( - ctx: SecurityFailureContext[F, A] - )(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[ServerResponse[B]] - - /** Called when the given request hasn't been successfully decoded, because of the given failure on the given input. This is captured by - * the `ctx` parameter. - * - * Might be called multiple times per request. - * - * @return - * An effect, describing the optional server response. If `None`, the next endpoint will be tried (if any). - */ - def onDecodeFailure( - ctx: DecodeFailureContext - )(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[Option[ServerResponse[B]]] -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/Interceptor.scala b/tapir/core/src/main/scala/sttp/tapir/server/interceptor/Interceptor.scala deleted file mode 100644 index b576979d..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/Interceptor.scala +++ /dev/null @@ -1,162 +0,0 @@ -package sttp.tapir.server.interceptor - -import sttp.monad.MonadError -import sttp.monad.syntax._ -import sttp.tapir.model.ServerRequest -import sttp.tapir.server.ServerEndpoint -import sttp.tapir.server.model.{ServerResponse, ValuedEndpointOutput} - -/** Intercepts requests, and endpoint decode events. Using interceptors it's possible to: - * - * - customise the request that is passed downstream - * - short-circuit further processing and provide an alternate (or no) response - * - replace or modify the response that is sent back to the client - * - * Interceptors can be called when the request is started to be processed (use [[RequestInterceptor]] in this case), or for each endpoint, - * with either input success of failure decoding events (see [[EndpointInterceptor]]). - * - * To add an interceptors, modify the server options of the server interpreter. - * - * @tparam F - * The effect type constructor. - */ -sealed trait Interceptor[F[_]] - -/** Allows intercepting the handling of `request`, before decoding using any of the endpoints is done. The request can be modified, before - * invoking further behavior, passed through `requestHandler`. Ultimately, when all interceptors are run, logic decoding subsequent - * endpoint inputs will be run. - * - * A request interceptor is called once for a request. - * - * Instead of calling the nested behavior, alternative responses can be returned using the `responder`. - * - * Moreover, when calling `requestHandler`, an [[EndpointInterceptor]] can be provided, which will be added to the list of endpoint - * interceptors to call. The order in which the endpoint interceptors will be called will correspond to their order in the interceptors - * list in the server options. An "empty" interceptor can be provided using [[EndpointInterceptor.noop]]. - * - * @tparam F - * The effect type constructor. - */ -trait RequestInterceptor[F[_]] extends Interceptor[F] { - - /** @tparam R - * The interpreter-specific supported capabilities, such as streaming support, websockets or `Any`. - * @tparam B - * The interpreter-specific, low-level type of body. - */ - def apply[R, B](responder: Responder[F, B], requestHandler: EndpointInterceptor[F] => RequestHandler[F, R, B]): RequestHandler[F, R, B] -} - -object RequestInterceptor { - - /** Create a request interceptor which transforms the server request, prior to handling any endpoints. */ - def transformServerRequest[F[_]](f: ServerRequest => F[ServerRequest]): RequestInterceptor[F] = new RequestInterceptor[F] { - override def apply[R, B]( - responder: Responder[F, B], - requestHandler: EndpointInterceptor[F] => RequestHandler[F, R, B] - ): RequestHandler[F, R, B] = - new RequestHandler[F, R, B] { - override def apply(request: ServerRequest, endpoints: List[ServerEndpoint[R, F]])(implicit - monad: MonadError[F] - ): F[RequestResult[B]] = - f(request).flatMap(request2 => requestHandler(EndpointInterceptor.noop)(request2, endpoints)) - } - } - - trait RequestResultTransform[F[_]] { - def apply[B](request: ServerRequest, result: RequestResult[B]): F[RequestResult[B]] - } - - /** Create a request interceptor which transforms the result, which might be either a response, or a list of endpoint decoding failures. - */ - def transformResult[F[_]](f: RequestResultTransform[F]): RequestInterceptor[F] = new RequestInterceptor[F] { - override def apply[R, B]( - responder: Responder[F, B], - requestHandler: EndpointInterceptor[F] => RequestHandler[F, R, B] - ): RequestHandler[F, R, B] = - new RequestHandler[F, R, B] { - override def apply(request: ServerRequest, endpoints: List[ServerEndpoint[R, F]])(implicit - monad: MonadError[F] - ): F[RequestResult[B]] = - requestHandler(EndpointInterceptor.noop)(request, endpoints).flatMap(f(request, _)) - } - } - - trait RequestResultEffectTransform[F[_]] { - def apply[B](request: ServerRequest, result: F[RequestResult[B]]): F[RequestResult[B]] - } - - /** Create a request interceptor which transforms the *effect* which computes the result (either a response, or a list of endpoint - * decoding failures), that is the `F[RequestResult[B]]` value. To transform the result itself, it might be easier to use - * [[transformResult]]. - */ - def transformResultEffect[F[_]](f: RequestResultEffectTransform[F]): RequestInterceptor[F] = new RequestInterceptor[F] { - override def apply[R, B]( - responder: Responder[F, B], - requestHandler: EndpointInterceptor[F] => RequestHandler[F, R, B] - ): RequestHandler[F, R, B] = - new RequestHandler[F, R, B] { - override def apply(request: ServerRequest, endpoints: List[ServerEndpoint[R, F]])(implicit - monad: MonadError[F] - ): F[RequestResult[B]] = - f(request, requestHandler(EndpointInterceptor.noop)(request, endpoints)) - } - } - - trait ServerEndpointFilter[F[_]] { - def apply[R](endpoints: List[ServerEndpoint[R, F]]): F[List[ServerEndpoint[R, F]]] - } - - /** Filter the server endpoints for which decoding will be later attempted, in sequence. */ - def filterServerEndpoints[F[_]](filter: ServerEndpointFilter[F]): RequestInterceptor[F] = - new RequestInterceptor[F] { - override def apply[R, B]( - responder: Responder[F, B], - requestHandler: EndpointInterceptor[F] => RequestHandler[F, R, B] - ): RequestHandler[F, R, B] = { - new RequestHandler[F, R, B] { - override def apply(request: ServerRequest, endpoints: List[ServerEndpoint[R, F]])(implicit - monad: MonadError[F] - ): F[RequestResult[B]] = { - filter(endpoints).flatMap(endpoints2 => requestHandler(EndpointInterceptor.noop)(request, endpoints2)) - } - } - } - } - - /** Run an effect when a request is received. */ - def effect[F[_]](f: ServerRequest => F[Unit]): RequestInterceptor[F] = new RequestInterceptor[F] { - override def apply[R, B]( - responder: Responder[F, B], - requestHandler: EndpointInterceptor[F] => RequestHandler[F, R, B] - ): RequestHandler[F, R, B] = - new RequestHandler[F, R, B] { - override def apply(request: ServerRequest, endpoints: List[ServerEndpoint[R, F]])(implicit - monad: MonadError[F] - ): F[RequestResult[B]] = - f(request).flatMap(_ => requestHandler(EndpointInterceptor.noop)(request, endpoints)) - } - } -} - -/** Allows intercepting the handling of a request by an endpoint, when either the endpoint's inputs have been decoded successfully, or when - * decoding has failed. Ultimately, when all interceptors are run, the endpoint's server logic will be run (in case of a decode success), - * or `None` will be returned (in case of decode failure). - * - * Instead of calling the nested behavior, alternative responses can be returned using the `responder`. - */ -trait EndpointInterceptor[F[_]] extends Interceptor[F] { - - /** @tparam B The interpreter-specific, low-level type of body. */ - def apply[B](responder: Responder[F, B], endpointHandler: EndpointHandler[F, B]): EndpointHandler[F, B] -} - -object EndpointInterceptor { - def noop[F[_]]: EndpointInterceptor[F] = new EndpointInterceptor[F] { - override def apply[B](responder: Responder[F, B], endpointHandler: EndpointHandler[F, B]): EndpointHandler[F, B] = endpointHandler - } -} - -trait Responder[F[_], B] { - def apply[O](request: ServerRequest, output: ValuedEndpointOutput[O]): F[ServerResponse[B]] -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/RequestHandler.scala b/tapir/core/src/main/scala/sttp/tapir/server/interceptor/RequestHandler.scala deleted file mode 100644 index 60a698d6..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/RequestHandler.scala +++ /dev/null @@ -1,18 +0,0 @@ -package sttp.tapir.server.interceptor - -import sttp.monad.MonadError -import sttp.tapir.model.ServerRequest -import sttp.tapir.server.ServerEndpoint - -trait RequestHandler[F[_], R, B] { - def apply(request: ServerRequest, endpoints: List[ServerEndpoint[R, F]])(implicit monad: MonadError[F]): F[RequestResult[B]] -} - -object RequestHandler { - def from[F[_], R, B](f: (ServerRequest, List[ServerEndpoint[R, F]], MonadError[F]) => F[RequestResult[B]]): RequestHandler[F, R, B] = - new RequestHandler[F, R, B] { - override def apply(request: ServerRequest, endpoints: List[ServerEndpoint[R, F]])(implicit - monad: MonadError[F] - ): F[RequestResult[B]] = f(request, endpoints, monad) - } -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/RequestResult.scala b/tapir/core/src/main/scala/sttp/tapir/server/interceptor/RequestResult.scala deleted file mode 100644 index 74391991..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/RequestResult.scala +++ /dev/null @@ -1,10 +0,0 @@ -package sttp.tapir.server.interceptor - -import sttp.tapir.server.model.ServerResponse - -/** The result of processing a request: either a response, or a list of endpoint decoding failures. */ -sealed trait RequestResult[+B] -object RequestResult { - case class Response[B](response: ServerResponse[B]) extends RequestResult[B] - case class Failure(failures: List[DecodeFailureContext]) extends RequestResult[Nothing] -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/SecurityFailureContext.scala b/tapir/core/src/main/scala/sttp/tapir/server/interceptor/SecurityFailureContext.scala deleted file mode 100644 index 0bd0263b..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/SecurityFailureContext.scala +++ /dev/null @@ -1,13 +0,0 @@ -package sttp.tapir.server.interceptor - -import sttp.tapir.Endpoint -import sttp.tapir.model.ServerRequest -import sttp.tapir.server.ServerEndpoint - -case class SecurityFailureContext[F[_], A]( - serverEndpoint: ServerEndpoint.Full[A, _, _, _, _, _, F], - securityInput: A, - request: ServerRequest -) { - def endpoint: Endpoint[A, _, _, _, _] = serverEndpoint.endpoint -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/content/NotAcceptableInterceptor.scala b/tapir/core/src/main/scala/sttp/tapir/server/interceptor/content/NotAcceptableInterceptor.scala deleted file mode 100644 index 163133ec..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/content/NotAcceptableInterceptor.scala +++ /dev/null @@ -1,47 +0,0 @@ -package sttp.tapir.server.interceptor.content - -import sttp.model.{ContentTypeRange, StatusCode} -import sttp.monad.MonadError -import sttp.tapir.internal._ -import sttp.tapir.server.interceptor._ -import sttp.tapir.server.interpreter.BodyListener -import sttp.tapir.server.model.ServerResponse -import sttp.tapir.{server, _} - -/** If no body in the endpoint's outputs satisfies the constraints from the request's `Accept` header, returns an empty response with status - * code 406, before any further processing (running the business logic) is done. - */ -class NotAcceptableInterceptor[F[_]] extends EndpointInterceptor[F] { - - override def apply[B](responder: Responder[F, B], endpointHandler: EndpointHandler[F, B]): EndpointHandler[F, B] = - new EndpointHandler[F, B] { - override def onDecodeSuccess[A, U, I]( - ctx: DecodeSuccessContext[F, A, U, I] - )(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[ServerResponse[B]] = { - ctx.request.acceptsContentTypes match { - case _ @(Right(Nil) | Right(ContentTypeRange.AnyRange :: Nil)) => endpointHandler.onDecodeSuccess(ctx) - case Right(ranges) => - val supportedMediaTypes = ctx.endpoint.output.supportedMediaTypes - // empty supported media types -> no body is defined, so the accepts header can be ignored - val hasMatchingRepresentation = supportedMediaTypes.exists(mt => ranges.exists(mt.matches)) || supportedMediaTypes.isEmpty - - if (hasMatchingRepresentation) endpointHandler.onDecodeSuccess(ctx) - else responder(ctx.request, server.model.ValuedEndpointOutput(statusCode(StatusCode.NotAcceptable), ())) - - case Left(_) => - // we're forgiving, if we can't parse the accepts header, we try to return any response - endpointHandler.onDecodeSuccess(ctx) - } - } - - override def onSecurityFailure[A]( - ctx: SecurityFailureContext[F, A] - )(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[ServerResponse[B]] = - endpointHandler.onSecurityFailure(ctx) - - override def onDecodeFailure( - ctx: DecodeFailureContext - )(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[Option[ServerResponse[B]]] = - endpointHandler.onDecodeFailure(ctx) - } -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/cors/CORSConfig.scala b/tapir/core/src/main/scala/sttp/tapir/server/interceptor/cors/CORSConfig.scala deleted file mode 100644 index 01ac9f3a..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/cors/CORSConfig.scala +++ /dev/null @@ -1,174 +0,0 @@ -package sttp.tapir.server.interceptor.cors - -import sttp.model.headers.Origin -import sttp.model.{Method, StatusCode} -import sttp.tapir.server.interceptor.cors.CORSConfig._ - -import scala.concurrent.duration.Duration - -case class CORSConfig( - allowedOrigin: AllowedOrigin, - allowedCredentials: AllowedCredentials, - allowedMethods: AllowedMethods, - allowedHeaders: AllowedHeaders, - exposedHeaders: ExposedHeaders, - maxAge: MaxAge, - preflightResponseStatusCode: StatusCode -) { - - /** Allows CORS requests from any origin. - * - * Sets the `Access-Control-Allow-Origin` response header to `*`. - */ - def allowAllOrigins: CORSConfig = copy(allowedOrigin = AllowedOrigin.All) - - /** Allows CORS requests only from a specific origin. - * - * If the request `Origin` matches the given `origin`, sets the `Access-Control-Allow-Origin` response header to the given `origin`. - * Otherwise the `Access-Control-Allow-Origin` response header is suppressed. - */ - def allowOrigin(origin: Origin): CORSConfig = copy(allowedOrigin = AllowedOrigin.Single(origin)) - - /** Allows CORS requests from origins matching predicate - * - * If the request `Origin` header matches the given `predicate`, sets the `Access-Control-Allow-Origin` response header to the given - * `origin`. Otherwise the `Access-Control-Allow-Origin` response header is suppressed. - */ - def allowMatchingOrigins(predicate: String => Boolean): CORSConfig = copy(allowedOrigin = AllowedOrigin.Matching(predicate)) - - /** Allows credentialed requests by setting the `Access-Control-Allow-Credentials` response header to `true` - */ - def allowCredentials: CORSConfig = copy(allowedCredentials = AllowedCredentials.Allow) - - /** Blocks credentialed requests by suppressing the `Access-Control-Allow-Credentials` response header - */ - def denyCredentials: CORSConfig = copy(allowedCredentials = AllowedCredentials.Deny) - - /** Allows CORS requests using any method. - * - * Sets the `Access-Control-Allow-Methods` response header to `*` - */ - def allowAllMethods: CORSConfig = copy(allowedMethods = AllowedMethods.All) - - /** Allows CORS requests using specific methods. - * - * If the preflight request's `Access-Control-Request-Method` header requests one of the specified `methods`, the - * `Access-Control-Allow-Methods` response header is set to a comma-separated list of the given `methods`. Otherwise the - * `Access-Control-Allow-Methods` response header is suppressed. - */ - def allowMethods(methods: Method*): CORSConfig = copy(allowedMethods = AllowedMethods.Some(methods.toSet)) - - /** Allows CORS requests with any headers. - * - * Sets the `Access-Control-Allow-Headers` response header to `*` - */ - def allowAllHeaders: CORSConfig = copy(allowedHeaders = AllowedHeaders.All) - - /** Allows CORS requests using specific headers. - * - * If the preflight request's `Access-Control-Request-Headers` header requests one of the specified `headers`, the - * `Access-Control-Allow-Headers` response header is set to a comma-separated list of the given `headers`. Otherwise the - * `Access-Control-Allow-Headers` response header is suppressed. - */ - def allowHeaders(headerNames: String*): CORSConfig = copy(allowedHeaders = AllowedHeaders.Some(headerNames.toSet)) - - /** Allows CORS requests using any headers requested in preflight request's `Access-Control-Request-Headers` header. - * - * Use [[reflectHeaders]] instead of [[allowAllHeaders]] when credentialed requests are enabled with [[allowCredentials]], since - * wildcards are illegal when credentials are enabled - */ - def reflectHeaders: CORSConfig = copy(allowedHeaders = AllowedHeaders.Reflect) - - /** Exposes all response headers to JavaScript in browsers - * - * Sets the `Access-Control-Expose-Headers` response header to `*` - */ - def exposeAllHeaders: CORSConfig = copy(exposedHeaders = ExposedHeaders.All) - - /** Exposes no response headers to JavaScript in browsers - * - * Suppresses the `Access-Control-Expose-Headers` response header - */ - def exposeNoHeaders: CORSConfig = copy(exposedHeaders = ExposedHeaders.None) - - /** Exposes specific response headers to JavaScript in browsers - * - * Sets the `Access-Control-Expose-Headers` response header to a comma-separated list of the given `headerNames` - */ - def exposeHeaders(headerNames: String*): CORSConfig = copy(exposedHeaders = ExposedHeaders.Some(headerNames.toSet)) - - /** Determines how long the response to a preflight request can be cached by the client. - * - * Suppresses the `Access-Control-Max-Age` response header, which makes the client use its default value. - */ - def defaultMaxAge: CORSConfig = copy(maxAge = MaxAge.Default) - - /** Determines how long the response to a preflight request can be cached by the client. - * - * Sets the `Access-Control-Max-Age` response header to the given `duration` in seconds. - */ - def maxAge(duration: Duration): CORSConfig = copy(maxAge = MaxAge.Some(duration)) - - /** Sets the response status code of successful preflight requests to "204 No Content" - */ - def defaultPreflightResponseStatusCode: CORSConfig = copy(preflightResponseStatusCode = StatusCode.NoContent) - - /** Sets the response status code of successful preflight requests to the given `statusCode` - */ - def preflightResponseStatusCode(statusCode: StatusCode): CORSConfig = copy(preflightResponseStatusCode = statusCode) - - /** When credentialed requests are enabled, any wildcard in allowed origin/headers/methods is illegal */ - private[cors] def isValid: Boolean = - allowedCredentials == AllowedCredentials.Deny || (allowedOrigin != AllowedOrigin.All && allowedHeaders != AllowedHeaders.All && allowedMethods != AllowedMethods.All) -} - -object CORSConfig { - val default: CORSConfig = CORSConfig( - allowedOrigin = AllowedOrigin.All, - allowedCredentials = AllowedCredentials.Deny, - allowedMethods = AllowedMethods.Some(Set(Method.GET, Method.HEAD, Method.POST, Method.PUT, Method.DELETE)), - allowedHeaders = AllowedHeaders.Reflect, - exposedHeaders = ExposedHeaders.None, - maxAge = MaxAge.Default, - preflightResponseStatusCode = StatusCode.NoContent - ) - - sealed trait AllowedOrigin - object AllowedOrigin { - case object All extends AllowedOrigin - case class Single(origin: Origin) extends AllowedOrigin - case class Matching(predicate: String => Boolean) extends AllowedOrigin - } - - sealed trait AllowedCredentials - object AllowedCredentials { - case object Allow extends AllowedCredentials - case object Deny extends AllowedCredentials - } - - sealed trait AllowedMethods - object AllowedMethods { - case object All extends AllowedMethods - case class Some(methods: Set[Method]) extends AllowedMethods - } - - sealed trait AllowedHeaders - object AllowedHeaders { - case object All extends AllowedHeaders - case class Some(headersNames: Set[String]) extends AllowedHeaders - case object Reflect extends AllowedHeaders - } - - sealed trait ExposedHeaders - object ExposedHeaders { - case object All extends ExposedHeaders - case class Some(headerNames: Set[String]) extends ExposedHeaders - case object None extends ExposedHeaders - } - - sealed trait MaxAge - object MaxAge { - case class Some(duration: Duration) extends MaxAge - case object Default extends MaxAge - } -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/cors/CORSInterceptor.scala b/tapir/core/src/main/scala/sttp/tapir/server/interceptor/cors/CORSInterceptor.scala deleted file mode 100644 index e904c023..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/cors/CORSInterceptor.scala +++ /dev/null @@ -1,163 +0,0 @@ -package sttp.tapir.server.interceptor.cors - -import sttp.model.{Header, HeaderNames, Method} -import sttp.monad.MonadError -import sttp.monad.syntax.MonadErrorOps -import sttp.tapir.model.ServerRequest -import sttp.tapir.server.ServerEndpoint -import sttp.tapir.server.interceptor.RequestResult.Response -import sttp.tapir.server.interceptor.cors.CORSConfig._ -import sttp.tapir.server.interceptor.{EndpointInterceptor, RequestHandler, RequestInterceptor, RequestResult, Responder} -import sttp.tapir.server.model.ServerResponse - -class CORSInterceptor[F[_]] private (config: CORSConfig) extends RequestInterceptor[F] { - override def apply[R, B]( - responder: Responder[F, B], - requestHandler: EndpointInterceptor[F] => RequestHandler[F, R, B] - ): RequestHandler[F, R, B] = - new RequestHandler[F, R, B] { - override def apply(request: ServerRequest, endpoints: List[ServerEndpoint[R, F]])(implicit me: MonadError[F]): F[RequestResult[B]] = { - val handler = request.header(HeaderNames.Origin).map(cors).getOrElse(nonCors(_, _)) - handler(request, endpoints) - } - - private val next = requestHandler(EndpointInterceptor.noop) - - private def cors( - origin: String - )(request: ServerRequest, endpoints: List[ServerEndpoint[R, F]])(implicit me: MonadError[F]): F[RequestResult[B]] = { - def headerValues(rawHeader: String): Set[String] = rawHeader.split("\\s*,\\s*").toSet - - def preflight(requestedHeaderNames: Set[String], requestedMethodName: String): F[RequestResult[B]] = { - val requestedMethod = Method.safeApply(requestedMethodName) match { - case Left(_) => None - case Right(method) => Some(method) - } - - val responseHeaders = List( - ResponseHeaders.allowOrigin(origin), - requestedMethod.flatMap(ResponseHeaders.allowMethods), - ResponseHeaders.allowHeaders(requestedHeaderNames), - ResponseHeaders.allowCredentials, - ResponseHeaders.maxAge, - ResponseHeaders.varyPreflight - ).flatten - - me.unit(Response(ServerResponse(config.preflightResponseStatusCode, responseHeaders, None, None))) - } - - def nonPreflight: F[RequestResult[B]] = { - val responseHeaders = List( - ResponseHeaders.allowOrigin(origin), - ResponseHeaders.allowCredentials, - ResponseHeaders.exposeHeaders, - ResponseHeaders.varyNonPreflight - ).flatten - - next(request, endpoints).map { - case Response(serverResponse) => Response(serverResponse.addHeaders(responseHeaders)) - case failure => failure - } - } - - request.method match { - case Method.OPTIONS => - request.header(HeaderNames.AccessControlRequestMethod) match { - case Some(requestedMethodName) => - val requestedHeaderNames = request.header(HeaderNames.AccessControlRequestHeaders).map(headerValues).getOrElse(Set.empty) - preflight(requestedHeaderNames, requestedMethodName) - case None => - nonPreflight - } - case _ => nonPreflight - } - } - - private def nonCors(request: ServerRequest, endpoints: List[ServerEndpoint[R, F]])(implicit - monad: MonadError[F] - ): F[RequestResult[B]] = next(request, endpoints) - } - - private[cors] object ResponseHeaders { - private val Wildcard = "*" - private val AnyMethod = Method(Wildcard) - private val AllowAnyOrigin = Header.accessControlAllowOrigin(Wildcard) - private val AllowAnyHeaders = Header.accessControlAllowHeaders(Wildcard) - private val ExposeAllHeaders = Header.accessControlExposeHeaders(Wildcard) - - def allowOrigin(origin: String): Option[Header] = config.allowedOrigin match { - case AllowedOrigin.All => Some(AllowAnyOrigin) - case AllowedOrigin.Single(allowedOrigin) if origin.equalsIgnoreCase(allowedOrigin.toString) => - Some(Header.accessControlAllowOrigin(origin)) - case AllowedOrigin.Matching(predicate) if predicate(origin) => - Some(Header.accessControlAllowOrigin(origin)) - case _ => None - } - - def allowCredentials: Option[Header] = config.allowedCredentials match { - case AllowedCredentials.Allow => Some(Header.accessControlAllowCredentials(true)) - case AllowedCredentials.Deny => None - } - - def allowMethods(method: Method): Option[Header] = config.allowedMethods match { - case AllowedMethods.All => Some(Header.accessControlAllowMethods(AnyMethod)) - case AllowedMethods.Some(methods) if methods.exists(_.is(method)) => Some(Header.accessControlAllowMethods(methods.toList: _*)) - case _ => None - } - - def allowHeaders(requestHeaderNames: Set[String]): Option[Header] = config.allowedHeaders match { - case AllowedHeaders.All => Some(AllowAnyHeaders) - case AllowedHeaders.Some(headerNames) => Some(Header.accessControlAllowHeaders(headerNames.toList: _*)) - case AllowedHeaders.Reflect => Some(Header.accessControlAllowHeaders(requestHeaderNames.toList: _*)) - } - - def exposeHeaders: Option[Header] = config.exposedHeaders match { - case ExposedHeaders.All => Some(ExposeAllHeaders) - case ExposedHeaders.Some(headerNames) => Some(Header.accessControlExposeHeaders(headerNames.toList: _*)) - case ExposedHeaders.None => None - } - - def maxAge: Option[Header] = config.maxAge match { - case MaxAge.Some(duration) => Some(Header.accessControlMaxAge(duration.toSeconds)) - case MaxAge.Default => None - } - - def varyPreflight: Option[Header] = { - val origin = config.allowedOrigin match { - case AllowedOrigin.All => Nil - case AllowedOrigin.Single(_) | AllowedOrigin.Matching(_) => List(HeaderNames.Origin) - } - - val methods = config.allowedMethods match { - case AllowedMethods.All => Nil - case AllowedMethods.Some(_) => List(HeaderNames.AccessControlRequestMethod) - } - - val headers = config.allowedHeaders match { - case AllowedHeaders.All => Nil - case AllowedHeaders.Some(_) | AllowedHeaders.Reflect => List(HeaderNames.AccessControlRequestHeaders) - } - - origin ++ methods ++ headers match { - case Nil => None - case headerNames => Some(Header.vary(headerNames: _*)) - } - } - - def varyNonPreflight: Option[Header] = config.allowedOrigin match { - case AllowedOrigin.All => None - case AllowedOrigin.Single(_) | AllowedOrigin.Matching(_) => Some(Header.vary(HeaderNames.Origin)) - } - } -} - -object CORSInterceptor { - def default[F[_]]: CORSInterceptor[F] = new CORSInterceptor[F](CORSConfig.default) - - def customOrThrow[F[_]](customConfig: CORSConfig): CORSInterceptor[F] = - if (customConfig.isValid) new CORSInterceptor[F](customConfig) - else - throw new IllegalArgumentException( - "Illegal CORS config. For security reasons, when allowCredentials is set to Allow, none of: allowOrigin, allowHeaders, allowMethods can be set to All." - ) -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/decodefailure/DecodeFailureHandler.scala b/tapir/core/src/main/scala/sttp/tapir/server/interceptor/decodefailure/DecodeFailureHandler.scala deleted file mode 100644 index 0898bfb0..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/decodefailure/DecodeFailureHandler.scala +++ /dev/null @@ -1,303 +0,0 @@ -package sttp.tapir.server.interceptor.decodefailure - -import sttp.model.{Header, HeaderNames, StatusCode} -import sttp.tapir.DecodeResult.Error.{JsonDecodeException, MultipartDecodeException} -import sttp.tapir.DecodeResult._ -import sttp.tapir.internal.RichEndpoint -import sttp.tapir.server.interceptor.DecodeFailureContext -import sttp.tapir.server.model.ValuedEndpointOutput -import sttp.tapir.{DecodeResult, EndpointIO, EndpointInput, ValidationError, Validator, server, _} - -import scala.annotation.tailrec - -trait DecodeFailureHandler { - - /** Given the context in which a decode failure occurred (the request, the input and the failure), returns an optional response to the - * request. `None` indicates that no action should be taken, and the request might be passed for decoding to other endpoints. - * - * Inputs are decoded in the following order: path, method, query, headers, body. Hence, if there's a decode failure on a query - * parameter, any method & path inputs of the input must have matched and must have been decoded successfully. - */ - def apply(ctx: DecodeFailureContext): Option[ValuedEndpointOutput[_]] -} - -/** A decode failure handler, which: - * - decides whether the given decode failure should lead to a response (and if so, with which status code and headers), using `respond` - * - in case a response is sent, creates the message using `failureMessage` - * - in case a response is sent, creates the response using `response`, given the status code, headers, and the created failure message. - * By default, the headers might include authentication challenge. - */ -case class DefaultDecodeFailureHandler( - respond: DecodeFailureContext => Option[(StatusCode, List[Header])], - failureMessage: DecodeFailureContext => String, - response: (StatusCode, List[Header], String) => ValuedEndpointOutput[_] -) extends DecodeFailureHandler { - def apply(ctx: DecodeFailureContext): Option[ValuedEndpointOutput[_]] = { - respond(ctx) match { - case Some((sc, hs)) => - val failureMsg = failureMessage(ctx) - Some(response(sc, hs, failureMsg)) - case None => None - } - } - - def response(messageOutput: String => ValuedEndpointOutput[_]): DefaultDecodeFailureHandler = - copy(response = (s, h, m) => messageOutput(m).prepend(statusCode.and(headers), (s, h))) -} - -object DefaultDecodeFailureHandler { - - /** The default implementation of the [[DecodeFailureHandler]]. - * - * A 400 (bad request) is returned if a query, header or body input can't be decoded (for any reason), or if decoding a path capture - * causes a validation error. - * - * A 401 (unauthorized) is returned when an authentication input (created using [[Tapir.auth]]) cannot be decoded. The appropriate - * `WWW-Authenticate` headers are included. - * - * Otherwise (e.g. if the method, a path segment, or path capture is missing, there's a mismatch or a decode error), `None` is returned, - * which is a signal to try the next endpoint. - * - * The error messages contain information about the source of the decode error, and optionally the validation error detail that caused - * the failure. - * - * The default decode failure handler can be customised by providing alternate functions for deciding whether a response should be sent, - * creating the error message and creating the response. - * - * Furthermore, how decode failures are handled can be adjusted globally by changing the flags passed to [[respond]]. By default, if the - * shape of the path for an endpoint matches the request, but decoding a path capture causes an error (e.g. a `path[Int]("amount")` - * cannot be parsed), the next endpoint is tried. However, if there's a validation error (e.g. a `path[Kind]("kind")`, where `Kind` is an - * enum, and a value outside the enumeration values is provided), a 400 response is sent. - * - * Finally, behavior can be adjusted per-endpoint-input, by setting an attribute. Import the [[OnDecodeFailure]] object and use the - * [[OnDecodeFailure.RichEndpointTransput.onDecodeFailureNextEndpoint]] extension method. - * - * This is only used for failures that occur when decoding inputs, not for exceptions that happen when the server logic is invoked. - * Exceptions can be either handled by the server logic, and converted to an error output value. Uncaught exceptions can be handled using - * the [[sttp.tapir.server.interceptor.exception.ExceptionInterceptor]]. - */ - val default: DefaultDecodeFailureHandler = DefaultDecodeFailureHandler( - respond(_), - FailureMessages.failureMessage, - failureResponse - ) - - /** A [[default]] handler which responds with a `404 Not Found`, instead of a `401 Unauthorized` or `400 Bad Request`, in case any input - * fails to decode, and the endpoint contains authentication inputs (created using [[Tapir.auth]]). No `WWW-Authenticate` headers are - * sent. - * - * Hence, the information if the endpoint exists, but needs authentication is hidden from the client. However, the existence of the - * endpoint might still be revealed using timing attacks. - */ - val hideEndpointsWithAuth: DefaultDecodeFailureHandler = - default.copy(respond = ctx => respondNotFoundIfHasAuth(ctx, default.respond(ctx))) - - def failureResponse(c: StatusCode, hs: List[Header], m: String): ValuedEndpointOutput[_] = - server.model.ValuedEndpointOutput(statusCode.and(headers).and(stringBody), (c, hs, m)) - - def respond( - ctx: DecodeFailureContext - ): Option[(StatusCode, List[Header])] = { - (failingInput(ctx), ctx.failure) match { - case (i: EndpointTransput.Atom[_], _) if i.attribute(OnDecodeFailure.key).contains(OnDecodeFailureNextEndpointAttribute()) => None - case (_: EndpointInput.Query[_], _) => respondBadRequest - case (_: EndpointInput.QueryParams[_], _) => respondBadRequest - case (_: EndpointInput.Cookie[_], _) => respondBadRequest - case (h: EndpointIO.Header[_], _: DecodeResult.Mismatch) if h.name == HeaderNames.ContentType => - respondUnsupportedMediaType - case (_: EndpointIO.Header[_], _) => respondBadRequest - case (fh: EndpointIO.FixedHeader[_], _: DecodeResult.Mismatch) if fh.h.name == HeaderNames.ContentType => - respondUnsupportedMediaType - case (_: EndpointIO.FixedHeader[_], _) => respondBadRequest - case (_: EndpointIO.Headers[_], _) => respondBadRequest - case (_: EndpointIO.Body[_, _], _) => respondBadRequest - case (_: EndpointIO.OneOfBody[_, _], _: DecodeResult.Mismatch) => respondUnsupportedMediaType - case (_: EndpointIO.StreamBodyWrapper[_, _], _) => respondBadRequest - // we assume that the only decode failure that might happen during path segment decoding is an error - // a non-standard path decoder might return Missing/Multiple/Mismatch, but that would be indistinguishable from - // a path shape mismatch - case (_: EndpointInput.PathCapture[_], _: DecodeResult.Error | _: DecodeResult.InvalidValue) => - respondBadRequest - case (_: EndpointInput.PathsCapture[_], _) => respondBadRequest - // if the failing input contains an authentication input (potentially nested), sending its challenge - case (FirstAuth(a), _) => Some((StatusCode.Unauthorized, Header.wwwAuthenticate(a.challenge))) - // other basic endpoints - the request doesn't match, but not returning a response (trying other endpoints) - case (_: EndpointInput.Basic[_], _) => None - // all other inputs (tuples, mapped) - responding with bad request - case _ => respondBadRequest - } - } - private val respondBadRequest = Some(onlyStatus(StatusCode.BadRequest)) - private val respondUnsupportedMediaType = Some(onlyStatus(StatusCode.UnsupportedMediaType)) - - def respondNotFoundIfHasAuth( - ctx: DecodeFailureContext, - response: Option[(StatusCode, List[Header])] - ): Option[(StatusCode, List[Header])] = response.map { r => - val e = ctx.endpoint - if (e.auths.nonEmpty) { - // all responses (both 400 and 401) are converted to a not-found - onlyStatus(StatusCode.NotFound) - } else r - } - - private def onlyStatus(status: StatusCode): (StatusCode, List[Header]) = (status, Nil) - - private def failingInput(ctx: DecodeFailureContext) = { - import sttp.tapir.internal.RichEndpointInput - ctx.failure match { - case DecodeResult.Missing => - def missingAuth(i: EndpointInput[_]) = i.pathTo(ctx.failingInput).collectFirst { case a: EndpointInput.Auth[_, _] => - a - } - missingAuth(ctx.endpoint.securityInput).orElse(missingAuth(ctx.endpoint.input)).getOrElse(ctx.failingInput) - case _ => ctx.failingInput - } - } - - private object FirstAuth { - def unapply(input: EndpointInput[_]): Option[EndpointInput.Auth[_, _]] = input match { - case a: EndpointInput.Auth[_, _] => Some(a) - case EndpointInput.MappedPair(input, _) => unapply(input) - case EndpointIO.MappedPair(input, _) => unapply(input) - case EndpointInput.Pair(left, right, _, _) => unapply(left).orElse(unapply(right)) - case EndpointIO.Pair(left, right, _, _) => unapply(left).orElse(unapply(right)) - case _ => None - } - } - - /** Default messages for [[DecodeResult.Failure]] s. */ - object FailureMessages { - - /** Describes the source of the failure: in which part of the request did the failure occur. */ - @tailrec - def failureSourceMessage(input: EndpointInput[_]): String = - input match { - case EndpointInput.FixedMethod(_, _, _) => s"Invalid value for: method" - case EndpointInput.FixedPath(_, _, _) => s"Invalid value for: path segment" - case EndpointInput.PathCapture(name, _, _) => s"Invalid value for: path parameter ${name.getOrElse("?")}" - case EndpointInput.PathsCapture(_, _) => s"Invalid value for: path" - case EndpointInput.Query(name, _, _, _) => s"Invalid value for: query parameter $name" - case EndpointInput.QueryParams(_, _) => "Invalid value for: query parameters" - case EndpointInput.Cookie(name, _, _) => s"Invalid value for: cookie $name" - case _: EndpointInput.ExtractFromRequest[_] => "Invalid value" - case a: EndpointInput.Auth[_, _] => failureSourceMessage(a.input) - case _: EndpointInput.MappedPair[_, _, _, _] => "Invalid value" - case _: EndpointIO.Body[_, _] => s"Invalid value for: body" - case _: EndpointIO.StreamBodyWrapper[_, _] => s"Invalid value for: body" - case EndpointIO.Header(name, _, _) => s"Invalid value for: header $name" - case EndpointIO.FixedHeader(name, _, _) => s"Invalid value for: header $name" - case EndpointIO.Headers(_, _) => s"Invalid value for: headers" - case _ => "Invalid value" - } - - def failureDetailMessage(failure: DecodeResult.Failure): Option[String] = failure match { - case InvalidValue(errors) if errors.nonEmpty => Some(ValidationMessages.validationErrorsMessage(errors)) - case Error(_, JsonDecodeException(errors, _)) if errors.nonEmpty => - Some( - errors - .map { error => - val at = if (error.path.nonEmpty) s" at '${error.path.map(_.encodedName).mkString(".")}'" else "" - error.msg + at - } - .mkString(", ") - ) - case Error(_, MultipartDecodeException(partFailures)) => - Some( - partFailures - .map { case (partName, partDecodeFailure) => - combineSourceAndDetail(s"part: $partName", failureDetailMessage(partDecodeFailure)) - } - .mkString(", ") - ) - case Missing => Some("missing") - case Multiple(_) => Some("multiple values") - case Mismatch(_, _) => Some("value mismatch") - case _ => None - } - - def combineSourceAndDetail(source: String, detail: Option[String]): String = - detail match { - case None => source - case Some(d) => s"$source ($d)" - } - - /** Default message describing the source of a decode failure, alongside with optional validation/decode failure details. */ - def failureMessage(ctx: DecodeFailureContext): String = { - val base = failureSourceMessage(ctx.failingInput) - val detail = failureDetailMessage(ctx.failure) - combineSourceAndDetail(base, detail) - } - } - - /** Default messages when the decode failure is due to a validation error. */ - object ValidationMessages { - - /** Default message describing why a value is invalid. - * @param valueName - * Name of the validated value to be used in error messages - */ - def invalidValueMessage[T](ve: ValidationError[T], valueName: String): String = { - ve.customMessage match { - case Some(message) => s"expected $valueName to pass validation: $message, but got: ${ve.invalidValue}" - case None => - ve.validator match { - case Validator.Min(value, exclusive) => - s"expected $valueName to be greater than ${if (exclusive) "" else "or equal to "}$value, but got ${ve.invalidValue}" - case Validator.Max(value, exclusive) => - s"expected $valueName to be less than ${if (exclusive) "" else "or equal to "}$value, but got ${ve.invalidValue}" - // TODO: convert to patterns when https://github.com/lampepfl/dotty/issues/12226 is fixed - case p: Validator.Pattern[T] => s"expected $valueName to match: ${p.value}, but got: ${quoteIfString(ve.invalidValue)}" - case m: Validator.MinLength[T] => - s"expected $valueName to have length greater than or equal to ${m.value}, but got: ${quoteIfString(ve.invalidValue)}" - case m: Validator.MaxLength[T] => - s"expected $valueName to have length less than or equal to ${m.value}, but got: ${quoteIfString(ve.invalidValue)}" - case m: Validator.MinSize[T, Iterable] @unchecked => - s"expected size of $valueName to be greater than or equal to ${m.value}, but got ${size(ve.invalidValue)}" - case m: Validator.MaxSize[T, Iterable] @unchecked => - s"expected size of $valueName to be less than or equal to ${m.value}, but got ${size(ve.invalidValue)}" - case Validator.Custom(_, _) => s"expected $valueName to pass validation, but got: ${quoteIfString(ve.invalidValue)}" - case Validator.Enumeration(possibleValues, encode, _) => - val encodedPossibleValues = - encode.fold(possibleValues.map(_.toString))(e => possibleValues.flatMap(e(_).toList).map(_.toString)) - s"expected $valueName to be one of ${encodedPossibleValues.mkString("(", ", ", ")")}, but got: ${quoteIfString(ve.invalidValue)}" - } - } - } - - /** Default message describing the path to an invalid value. This is the path inside the validated object, e.g. - * `user.address.street.name`. - */ - def pathMessage(path: List[FieldName]): Option[String] = - path match { - case Nil => None - case l => Some(l.map(_.encodedName).mkString(".")) - } - - /** Default message describing the validation error: which value is invalid, and why. */ - def validationErrorMessage(ve: ValidationError[_]): String = invalidValueMessage(ve, pathMessage(ve.path).getOrElse("value")) - - /** Default message describing a list of validation errors: which values are invalid, and why. */ - def validationErrorsMessage(ve: List[ValidationError[_]]): String = ve.map(validationErrorMessage).mkString(", ") - - private def quoteIfString(v: Any): Any = v match { - case s: String => s""""$s"""" - case _ => v - } - - private def size(v: Any): Any = v match { - case i: Iterable[_] => i.size - case _ => v - } - } - - private[decodefailure] case class OnDecodeFailureNextEndpointAttribute() - - object OnDecodeFailure { - private[decodefailure] val key: AttributeKey[OnDecodeFailureNextEndpointAttribute] = AttributeKey[OnDecodeFailureNextEndpointAttribute] - - implicit class RichEndpointTransput[ET <: EndpointTransput.Atom[_]](val et: ET) extends AnyVal { - def onDecodeFailureNextEndpoint: ET = et.attribute(key, OnDecodeFailureNextEndpointAttribute()).asInstanceOf[ET] - } - } -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/decodefailure/DecodeFailureInterceptor.scala b/tapir/core/src/main/scala/sttp/tapir/server/interceptor/decodefailure/DecodeFailureInterceptor.scala deleted file mode 100644 index 936b1fee..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/decodefailure/DecodeFailureInterceptor.scala +++ /dev/null @@ -1,31 +0,0 @@ -package sttp.tapir.server.interceptor.decodefailure - -import sttp.monad.MonadError -import sttp.monad.syntax._ -import sttp.tapir.server.interceptor._ -import sttp.tapir.server.interpreter.BodyListener -import sttp.tapir.server.model.ServerResponse - -class DecodeFailureInterceptor[F[_]](handler: DecodeFailureHandler) extends EndpointInterceptor[F] { - override def apply[B](responder: Responder[F, B], endpointHandler: EndpointHandler[F, B]): EndpointHandler[F, B] = - new EndpointHandler[F, B] { - override def onDecodeSuccess[A, U, I]( - ctx: DecodeSuccessContext[F, A, U, I] - )(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[ServerResponse[B]] = - endpointHandler.onDecodeSuccess(ctx) - - override def onSecurityFailure[A]( - ctx: SecurityFailureContext[F, A] - )(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[ServerResponse[B]] = - endpointHandler.onSecurityFailure(ctx) - - override def onDecodeFailure( - ctx: DecodeFailureContext - )(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[Option[ServerResponse[B]]] = { - handler(ctx) match { - case None => endpointHandler.onDecodeFailure(ctx) - case Some(valuedOutput) => responder(ctx.request, valuedOutput).map(Some(_)) - } - } - } -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/exception/ExceptionContext.scala b/tapir/core/src/main/scala/sttp/tapir/server/interceptor/exception/ExceptionContext.scala deleted file mode 100644 index 7b6975db..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/exception/ExceptionContext.scala +++ /dev/null @@ -1,6 +0,0 @@ -package sttp.tapir.server.interceptor.exception - -import sttp.tapir.AnyEndpoint -import sttp.tapir.model.ServerRequest - -case class ExceptionContext(e: Throwable, endpoint: AnyEndpoint, request: ServerRequest) diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/exception/ExceptionHandler.scala b/tapir/core/src/main/scala/sttp/tapir/server/interceptor/exception/ExceptionHandler.scala deleted file mode 100644 index b9c1a0a6..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/exception/ExceptionHandler.scala +++ /dev/null @@ -1,34 +0,0 @@ -package sttp.tapir.server.interceptor.exception - -import sttp.model.StatusCode -import sttp.monad.MonadError -import sttp.tapir.server.model.ValuedEndpointOutput -import sttp.tapir._ - -trait ExceptionHandler[F[_]] { - def apply(ctx: ExceptionContext)(implicit monad: MonadError[F]): F[Option[ValuedEndpointOutput[_]]] -} - -object ExceptionHandler { - def apply[F[_]](f: ExceptionContext => F[Option[ValuedEndpointOutput[_]]]): ExceptionHandler[F] = - new ExceptionHandler[F] { - override def apply(ctx: ExceptionContext)(implicit monad: MonadError[F]): F[Option[ValuedEndpointOutput[_]]] = - f(ctx) - } - - def pure[F[_]](f: ExceptionContext => Option[ValuedEndpointOutput[_]]): ExceptionHandler[F] = - new ExceptionHandler[F] { - override def apply(ctx: ExceptionContext)(implicit monad: MonadError[F]): F[Option[ValuedEndpointOutput[_]]] = - monad.unit(f(ctx)) - } -} - -case class DefaultExceptionHandler[F[_]](response: (StatusCode, String) => ValuedEndpointOutput[_]) extends ExceptionHandler[F] { - override def apply(ctx: ExceptionContext)(implicit monad: MonadError[F]): F[Option[ValuedEndpointOutput[_]]] = - monad.unit(Some(response(StatusCode.InternalServerError, "Internal server error"))) -} - -object DefaultExceptionHandler { - def apply[F[_]]: ExceptionHandler[F] = - DefaultExceptionHandler[F]((code: StatusCode, body: String) => ValuedEndpointOutput(statusCode.and(stringBody), (code, body))) -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/exception/ExceptionInterceptor.scala b/tapir/core/src/main/scala/sttp/tapir/server/interceptor/exception/ExceptionInterceptor.scala deleted file mode 100644 index d0e6a7d2..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/exception/ExceptionInterceptor.scala +++ /dev/null @@ -1,47 +0,0 @@ -package sttp.tapir.server.interceptor.exception - -import sttp.monad.MonadError -import sttp.monad.syntax._ -import sttp.tapir.AnyEndpoint -import sttp.tapir.model.ServerRequest -import sttp.tapir.server.interceptor._ -import sttp.tapir.server.interpreter.BodyListener -import sttp.tapir.server.model.ServerResponse - -import scala.util.control.NonFatal - -class ExceptionInterceptor[F[_]](handler: ExceptionHandler[F]) extends EndpointInterceptor[F] { - override def apply[B](responder: Responder[F, B], decodeHandler: EndpointHandler[F, B]): EndpointHandler[F, B] = - new EndpointHandler[F, B] { - override def onDecodeSuccess[A, U, I]( - ctx: DecodeSuccessContext[F, A, U, I] - )(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[ServerResponse[B]] = { - monad.handleError(decodeHandler.onDecodeSuccess(ctx)) { case NonFatal(e) => - onException(e, ctx.endpoint, ctx.request) - } - } - - override def onSecurityFailure[A]( - ctx: SecurityFailureContext[F, A] - )(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[ServerResponse[B]] = - monad.handleError(decodeHandler.onSecurityFailure(ctx)) { case NonFatal(e) => - onException(e, ctx.endpoint, ctx.request) - } - - override def onDecodeFailure( - ctx: DecodeFailureContext - )(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[Option[ServerResponse[B]]] = { - monad.handleError(decodeHandler.onDecodeFailure(ctx)) { case NonFatal(e) => - onException(e, ctx.endpoint, ctx.request).map(Some(_)) - } - } - - private def onException(e: Throwable, endpoint: AnyEndpoint, request: ServerRequest)(implicit - monad: MonadError[F] - ): F[ServerResponse[B]] = - handler(ExceptionContext(e, endpoint, request)).flatMap { - case Some(output) => responder(request, output) - case None => monad.error(e) - } - } -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/log/ExceptionContext.scala b/tapir/core/src/main/scala/sttp/tapir/server/interceptor/log/ExceptionContext.scala deleted file mode 100644 index b135a700..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/log/ExceptionContext.scala +++ /dev/null @@ -1,11 +0,0 @@ -package sttp.tapir.server.interceptor.log - -import sttp.tapir.Endpoint -import sttp.tapir.model.ServerRequest - -case class ExceptionContext[A, U]( - endpoint: Endpoint[A, _, _, _, _], - securityInput: Option[A], - principal: Option[U], - request: ServerRequest -) diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/log/ServerLog.scala b/tapir/core/src/main/scala/sttp/tapir/server/interceptor/log/ServerLog.scala deleted file mode 100644 index 528d00f5..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/log/ServerLog.scala +++ /dev/null @@ -1,142 +0,0 @@ -package sttp.tapir.server.interceptor.log - -import sttp.tapir.model.ServerRequest -import sttp.tapir.server.interceptor.{DecodeFailureContext, DecodeSuccessContext, SecurityFailureContext} -import sttp.tapir.server.model.ServerResponse -import sttp.tapir.{AnyEndpoint, DecodeResult} - -import java.time.Clock - -/** Used by [[ServerLogInterceptor]] to log how a request was handled. - * @tparam F[_] - * Interpreter-specific effect type constructor. - */ -trait ServerLog[F[_]] { - - /** The type of the per-request token that is generated when a request is started and passed to callbacks when the request is completed. - * E.g. `Unit` or a timestamp (`Long`). - */ - type TOKEN - - /** Invoked when the request has been received to obtain the per-request token, such as the starting timestamp. */ - def requestToken: TOKEN - - /** Invoked when the request has been received. */ - def requestReceived(request: ServerRequest, token: TOKEN): F[Unit] - - /** Invoked when there's a decode failure for an input of the endpoint and the interpreter, or other interceptors, haven't provided a - * response. - */ - def decodeFailureNotHandled(ctx: DecodeFailureContext, token: TOKEN): F[Unit] - - /** Invoked when there's a decode failure for an input of the endpoint and the interpreter, or other interceptors, provided a response. */ - def decodeFailureHandled(ctx: DecodeFailureContext, response: ServerResponse[_], token: TOKEN): F[Unit] - - /** Invoked when the security logic fails and returns an error. */ - def securityFailureHandled(ctx: SecurityFailureContext[F, _], response: ServerResponse[_], token: TOKEN): F[Unit] - - /** Invoked when all inputs of the request have been decoded successfully and the endpoint handles the request by providing a response, - * with the given status code. - */ - def requestHandled(ctx: DecodeSuccessContext[F, _, _, _], response: ServerResponse[_], token: TOKEN): F[Unit] - - /** Invoked when an exception has been thrown when running the server logic or handling decode failures. */ - def exception(ctx: ExceptionContext[_, _], ex: Throwable, token: TOKEN): F[Unit] - - /** Allows defining a list of endpoints which should not log requestHandled. Exceptions, decode failures and security failures will still - * be logged. - */ - def ignoreEndpoints: Set[AnyEndpoint] = Set.empty -} - -case class DefaultServerLog[F[_]]( - doLogWhenReceived: String => F[Unit], - doLogWhenHandled: (String, Option[Throwable]) => F[Unit], - doLogAllDecodeFailures: (String, Option[Throwable]) => F[Unit], - doLogExceptions: (String, Throwable) => F[Unit], - noLog: F[Unit], - logWhenReceived: Boolean = false, - logWhenHandled: Boolean = true, - logAllDecodeFailures: Boolean = false, - logLogicExceptions: Boolean = true, - showEndpoint: AnyEndpoint => String = _.showShort, - showRequest: ServerRequest => String = _.showShort, - showResponse: ServerResponse[_] => String = _.showShort, - includeTiming: Boolean = true, - clock: Clock = Clock.systemUTC(), - override val ignoreEndpoints: Set[AnyEndpoint] = Set.empty -) extends ServerLog[F] { - - def doLogWhenReceived(f: String => F[Unit]): DefaultServerLog[F] = copy(doLogWhenReceived = f) - def doLogWhenHandled(f: (String, Option[Throwable]) => F[Unit]): DefaultServerLog[F] = copy(doLogWhenHandled = f) - def doLogAllDecodeFailures(f: (String, Option[Throwable]) => F[Unit]): DefaultServerLog[F] = copy(doLogAllDecodeFailures = f) - def doLogExceptions(f: (String, Throwable) => F[Unit]): DefaultServerLog[F] = copy(doLogExceptions = f) - def noLog(f: F[Unit]): DefaultServerLog[F] = copy(noLog = f) - def logWhenHandled(doLog: Boolean): DefaultServerLog[F] = copy(logWhenHandled = doLog) - def logAllDecodeFailures(doLog: Boolean): DefaultServerLog[F] = copy(logAllDecodeFailures = doLog) - def logLogicExceptions(doLog: Boolean): DefaultServerLog[F] = copy(logLogicExceptions = doLog) - def showEndpoint(s: AnyEndpoint => String): DefaultServerLog[F] = copy(showEndpoint = s) - def showRequest(s: ServerRequest => String): DefaultServerLog[F] = copy(showRequest = s) - def showResponse(s: ServerResponse[_] => String): DefaultServerLog[F] = copy(showResponse = s) - def includeTiming(doInclude: Boolean): DefaultServerLog[F] = copy(includeTiming = doInclude) - def clock(c: Clock): DefaultServerLog[F] = copy(clock = c) - def ignoreEndpoints(es: Seq[AnyEndpoint]): DefaultServerLog[F] = copy(ignoreEndpoints = es.toSet) - - // - - override type TOKEN = Long - - override def requestToken: TOKEN = if (includeTiming) now() else 0 - - override def requestReceived(request: ServerRequest, token: TOKEN): F[Unit] = - if (logWhenReceived) doLogWhenReceived(s"Request received: ${showRequest(request)}") else noLog - - override def decodeFailureNotHandled(ctx: DecodeFailureContext, token: TOKEN): F[Unit] = - if (logAllDecodeFailures) - doLogAllDecodeFailures( - s"Request: ${showRequest(ctx.request)}, not handled by: ${showEndpoint(ctx.endpoint)}${took(token)}; decode failure: ${ctx.failure}, on input: ${ctx.failingInput.show}", - exception(ctx) - ) - else noLog - - override def decodeFailureHandled(ctx: DecodeFailureContext, response: ServerResponse[_], token: TOKEN): F[Unit] = - if (logWhenHandled) - doLogWhenHandled( - s"Request: ${showRequest(ctx.request)}, handled by: ${showEndpoint( - ctx.endpoint - )}${took(token)}; decode failure: ${ctx.failure}, on input: ${ctx.failingInput.show}; response: ${showResponse(response)}", - exception(ctx) - ) - else noLog - - override def securityFailureHandled(ctx: SecurityFailureContext[F, _], response: ServerResponse[_], token: TOKEN): F[Unit] = - if (logWhenHandled) - doLogWhenHandled( - s"Request: ${showRequest(ctx.request)}, handled by: ${showEndpoint(ctx.endpoint)}${took(token)}; security logic error response: ${showResponse(response)}", - None - ) - else noLog - - override def requestHandled(ctx: DecodeSuccessContext[F, _, _, _], response: ServerResponse[_], token: TOKEN): F[Unit] = - if (logWhenHandled) - doLogWhenHandled( - s"Request: ${showRequest(ctx.request)}, handled by: ${showEndpoint(ctx.endpoint)}${took(token)}; response: ${showResponse(response)}", - None - ) - else noLog - - override def exception(ctx: ExceptionContext[_, _], ex: Throwable, token: TOKEN): F[Unit] = - if (logLogicExceptions) - doLogExceptions(s"Exception when handling request: ${showRequest(ctx.request)}, by: ${showEndpoint(ctx.endpoint)}${took(token)}", ex) - else noLog - - private def now() = clock.instant().toEpochMilli - - private def took(token: TOKEN): String = if (includeTiming) s", took: ${now() - token}ms" else "" - - private def exception(ctx: DecodeFailureContext): Option[Throwable] = - ctx.failure match { - case DecodeResult.Error(_, error) => Some(error) - case _ => None - } -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/log/ServerLogInterceptor.scala b/tapir/core/src/main/scala/sttp/tapir/server/interceptor/log/ServerLogInterceptor.scala deleted file mode 100644 index 3536752e..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/log/ServerLogInterceptor.scala +++ /dev/null @@ -1,87 +0,0 @@ -package sttp.tapir.server.interceptor.log - -import sttp.monad.MonadError -import sttp.monad.syntax._ -import sttp.tapir.model.ServerRequest -import sttp.tapir.server.ServerEndpoint -import sttp.tapir.server.interceptor._ -import sttp.tapir.server.interpreter.BodyListener -import sttp.tapir.server.model.ServerResponse -import sttp.tapir.AnyEndpoint - -/** @tparam F The effect in which log messages are returned. */ -class ServerLogInterceptor[F[_]](serverLog: ServerLog[F]) extends RequestInterceptor[F] { - override def apply[R, B]( - responder: Responder[F, B], - requestHandler: EndpointInterceptor[F] => RequestHandler[F, R, B] - ): RequestHandler[F, R, B] = { - val token = serverLog.requestToken - val delegate = requestHandler(new ServerLogEndpointInterceptor[F, serverLog.TOKEN](serverLog, token)) - new RequestHandler[F, R, B] { - override def apply(request: ServerRequest, endpoints: List[ServerEndpoint[R, F]])(implicit - monad: MonadError[F] - ): F[RequestResult[B]] = { - serverLog.requestReceived(request, token).flatMap(_ => delegate(request, endpoints)) - } - } - } -} - -class ServerLogEndpointInterceptor[F[_], T](serverLog: ServerLog[F] { type TOKEN = T }, token: T) extends EndpointInterceptor[F] { - override def apply[B](responder: Responder[F, B], decodeHandler: EndpointHandler[F, B]): EndpointHandler[F, B] = - new EndpointHandler[F, B] { - override def onDecodeSuccess[A, U, I](ctx: DecodeSuccessContext[F, A, U, I])(implicit - monad: MonadError[F], - bodyListener: BodyListener[F, B] - ): F[ServerResponse[B]] = { - decodeHandler - .onDecodeSuccess(ctx) - .flatMap { response => - if (serverLog.ignoreEndpoints.contains(ctx.endpoint)) - response.unit - else - serverLog.requestHandled(ctx, response, token).map(_ => response) - } - .handleError { case e: Throwable => - serverLog - .exception(ExceptionContext(ctx.endpoint, Some(ctx.securityInput), Some(ctx.principal), ctx.request), e, token) - .flatMap(_ => monad.error(e)) - } - } - - override def onSecurityFailure[A]( - ctx: SecurityFailureContext[F, A] - )(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[ServerResponse[B]] = { - decodeHandler - .onSecurityFailure(ctx) - .flatMap { response => - serverLog.securityFailureHandled(ctx, response, token).map(_ => response) - } - .handleError { case e: Throwable => - serverLog - .exception(ExceptionContext(ctx.endpoint, Some(ctx.securityInput), None, ctx.request), e, token) - .flatMap(_ => monad.error(e)) - } - } - - override def onDecodeFailure( - ctx: DecodeFailureContext - )(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[Option[ServerResponse[B]]] = { - decodeHandler - .onDecodeFailure(ctx) - .flatMap { - case r @ None => - serverLog.decodeFailureNotHandled(ctx, token).map(_ => r: Option[ServerResponse[B]]) - case r @ Some(response) => - serverLog - .decodeFailureHandled(ctx, response, token) - .map(_ => r: Option[ServerResponse[B]]) - } - .handleError { case e: Throwable => - serverLog - .exception(ExceptionContext(ctx.endpoint, None, None, ctx.request), e, token) - .flatMap(_ => monad.error(e)) - } - } - } -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/metrics/MetricsEndpointInterceptor.scala b/tapir/core/src/main/scala/sttp/tapir/server/interceptor/metrics/MetricsEndpointInterceptor.scala deleted file mode 100644 index 57f8fc04..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/metrics/MetricsEndpointInterceptor.scala +++ /dev/null @@ -1,143 +0,0 @@ -package sttp.tapir.server.interceptor.metrics - -import sttp.monad.MonadError -import sttp.monad.syntax._ -import sttp.tapir.AnyEndpoint -import sttp.tapir.server.interceptor._ -import sttp.tapir.server.interpreter.BodyListener -import sttp.tapir.server.interpreter.BodyListener._ -import sttp.tapir.server.metrics.{EndpointMetric, Metric} -import sttp.tapir.server.model.ServerResponse - -import scala.util.{Failure, Success, Try} - -class MetricsRequestInterceptor[F[_]](metrics: List[Metric[F, _]], ignoreEndpoints: Seq[AnyEndpoint]) extends RequestInterceptor[F] { - - override def apply[R, B]( - responder: Responder[F, B], - requestHandler: EndpointInterceptor[F] => RequestHandler[F, R, B] - ): RequestHandler[F, R, B] = - RequestHandler.from { (request, endpoints, monad) => - implicit val m: MonadError[F] = monad - metrics - .foldLeft(List.empty[EndpointMetric[F]].unit) { (mAcc, metric) => - for { - metrics <- mAcc - endpointMetric <- metric match { - case Metric(m, onRequest) => onRequest(request, m, monad) - } - } yield endpointMetric :: metrics - } - .flatMap { endpointMetrics => - requestHandler(new MetricsEndpointInterceptor[F](endpointMetrics.reverse, ignoreEndpoints)).apply(request, endpoints) - } - } -} - -private[metrics] class MetricsEndpointInterceptor[F[_]]( - endpointMetrics: List[EndpointMetric[F]], - ignoreEndpoints: Seq[AnyEndpoint] -) extends EndpointInterceptor[F] { - - override def apply[B](responder: Responder[F, B], endpointHandler: EndpointHandler[F, B]): EndpointHandler[F, B] = - new EndpointHandler[F, B] { - - override def onDecodeSuccess[A, U, I]( - ctx: DecodeSuccessContext[F, A, U, I] - )(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[ServerResponse[B]] = { - if (ignoreEndpoints.contains(ctx.endpoint)) endpointHandler.onDecodeSuccess(ctx) - else { - val responseWithMetrics: F[ServerResponse[B]] = for { - _ <- collectRequestMetrics(ctx.endpoint) - response <- endpointHandler.onDecodeSuccess(ctx) - _ <- collectResponseHeadersMetrics(ctx.endpoint, response) - withMetrics <- withBodyOnComplete(ctx.endpoint, response) - } yield withMetrics - - handleResponseExceptions(responseWithMetrics, ctx.endpoint) - } - } - - /** Collects `onResponse` as well as `onRequest` metric which was not collected in `onDecodeSuccess` stage. */ - override def onSecurityFailure[A]( - ctx: SecurityFailureContext[F, A] - )(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[ServerResponse[B]] = { - if (ignoreEndpoints.contains(ctx.endpoint)) endpointHandler.onSecurityFailure(ctx) - else { - val responseWithMetrics: F[ServerResponse[B]] = for { - _ <- collectRequestMetrics(ctx.endpoint) - response <- endpointHandler.onSecurityFailure(ctx) - _ <- collectResponseHeadersMetrics(ctx.endpoint, response) - withMetrics <- withBodyOnComplete(ctx.endpoint, response) - } yield withMetrics - - handleResponseExceptions(responseWithMetrics, ctx.endpoint) - } - } - - /** If there's some `ServerResponse` collects `onResponse` as well as `onRequest` metric which was not collected in `onDecodeSuccess` - * stage. - */ - override def onDecodeFailure( - ctx: DecodeFailureContext - )(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[Option[ServerResponse[B]]] = { - if (ignoreEndpoints.contains(ctx.endpoint)) endpointHandler.onDecodeFailure(ctx) - else { - val responseWithMetrics: F[Option[ServerResponse[B]]] = for { - response <- endpointHandler.onDecodeFailure(ctx) - withMetrics <- response match { - case Some(response) => - for { - _ <- collectRequestMetrics(ctx.endpoint) - _ <- collectResponseHeadersMetrics(ctx.endpoint, response) - res <- withBodyOnComplete(ctx.endpoint, response) - } yield Some(res) - case None => monad.unit(None) - } - } yield withMetrics - - handleResponseExceptions(responseWithMetrics, ctx.endpoint) - } - } - } - - private def collectMetrics(pf: PartialFunction[EndpointMetric[F], F[Unit]])(implicit monad: MonadError[F]): F[Unit] = { - def sequence(metrics: List[EndpointMetric[F]]): F[Unit] = { - metrics match { - case Nil => ().unit - case m :: tail if pf.isDefinedAt(m) => pf(m).flatMap(_ => sequence(tail)) - case _ :: tail => sequence(tail) - } - } - sequence(endpointMetrics) - } - - private def withBodyOnComplete[B](endpoint: AnyEndpoint, sr: ServerResponse[B])(implicit - monad: MonadError[F], - bodyListener: BodyListener[F, B] - ): F[ServerResponse[B]] = { - val cb: Try[Unit] => F[Unit] = { - case Success(_) => - collectMetrics { case EndpointMetric(_, _, Some(onResponseBody), _) => onResponseBody(endpoint, sr) } - case Failure(ex) => - collectExceptionMetrics(endpoint, ex) - } - - sr match { - case sr @ ServerResponse(_, _, Some(body), _) => body.onComplete(cb).map(b => sr.copy(body = Some(b))) - case sr @ ServerResponse(_, _, None, _) => cb(Success(())).map(_ => sr) - } - } - - private def handleResponseExceptions[T](r: F[T], e: AnyEndpoint)(implicit monad: MonadError[F]): F[T] = - r.handleError { case ex: Exception => collectExceptionMetrics(e, ex) } - - private def collectExceptionMetrics[T](e: AnyEndpoint, ex: Throwable)(implicit monad: MonadError[F]): F[T] = - collectMetrics { case EndpointMetric(_, _, _, Some(onException)) => onException(e, ex) }.flatMap(_ => monad.error(ex)) - - private def collectRequestMetrics(endpoint: AnyEndpoint)(implicit monad: MonadError[F]): F[Unit] = - collectMetrics { case EndpointMetric(Some(onRequest), _, _, _) => onRequest(endpoint) } - - private def collectResponseHeadersMetrics[B](endpoint: AnyEndpoint, sr: ServerResponse[B])(implicit monad: MonadError[F]): F[Unit] = - collectMetrics { case EndpointMetric(_, Some(onResponseHeaders), _, _) => onResponseHeaders(endpoint, sr) } -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/reject/RejectHandler.scala b/tapir/core/src/main/scala/sttp/tapir/server/interceptor/reject/RejectHandler.scala deleted file mode 100644 index 2e6b92ec..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/reject/RejectHandler.scala +++ /dev/null @@ -1,56 +0,0 @@ -package sttp.tapir.server.interceptor.reject - -import sttp.model.StatusCode -import sttp.monad.MonadError -import sttp.tapir._ -import sttp.tapir.server.interceptor.RequestResult -import sttp.tapir.server.model.ValuedEndpointOutput - -trait RejectHandler[F[_]] { - def apply(failure: RequestResult.Failure)(implicit monad: MonadError[F]): F[Option[ValuedEndpointOutput[_]]] -} - -object RejectHandler { - def apply[F[_]](f: RequestResult.Failure => F[Option[ValuedEndpointOutput[_]]]): RejectHandler[F] = new RejectHandler[F] { - override def apply(failure: RequestResult.Failure)(implicit monad: MonadError[F]): F[Option[ValuedEndpointOutput[_]]] = - f(failure) - } - - def pure[F[_]](f: RequestResult.Failure => Option[ValuedEndpointOutput[_]]): RejectHandler[F] = new RejectHandler[F] { - override def apply(failure: RequestResult.Failure)(implicit monad: MonadError[F]): F[Option[ValuedEndpointOutput[_]]] = - monad.unit(f(failure)) - } -} - -case class DefaultRejectHandler[F[_]]( - response: (StatusCode, String) => ValuedEndpointOutput[_], - defaultStatusCodeAndBody: Option[(StatusCode, String)] -) extends RejectHandler[F] { - override def apply(failure: RequestResult.Failure)(implicit monad: MonadError[F]): F[Option[ValuedEndpointOutput[_]]] = { - import DefaultRejectHandler._ - - val statusCodeAndBody = if (hasMethodMismatch(failure)) Some(Responses.MethodNotAllowed) else defaultStatusCodeAndBody - monad.unit(statusCodeAndBody.map(response.tupled)) - } -} - -object DefaultRejectHandler { - def apply[F[_]]: RejectHandler[F] = - DefaultRejectHandler[F]((sc: StatusCode, m: String) => ValuedEndpointOutput(statusCode.and(stringBody), (sc, m)), None) - - def orNotFound[F[_]]: RejectHandler[F] = - DefaultRejectHandler[F]( - (sc: StatusCode, m: String) => ValuedEndpointOutput(statusCode.and(stringBody), (sc, m)), - Some(Responses.NotFound) - ) - - private def hasMethodMismatch(f: RequestResult.Failure): Boolean = f.failures.map(_.failingInput).exists { - case _: EndpointInput.FixedMethod[_] => true - case _ => false - } - - object Responses { - val NotFound: (StatusCode, String) = (StatusCode.NotFound, "Not Found") - val MethodNotAllowed: (StatusCode, String) = (StatusCode.MethodNotAllowed, "Method Not Allowed") - } -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/reject/RejectInterceptor.scala b/tapir/core/src/main/scala/sttp/tapir/server/interceptor/reject/RejectInterceptor.scala deleted file mode 100644 index eed89ee5..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interceptor/reject/RejectInterceptor.scala +++ /dev/null @@ -1,50 +0,0 @@ -package sttp.tapir.server.interceptor.reject - -import sttp.monad.MonadError -import sttp.monad.syntax._ -import sttp.tapir.model.ServerRequest -import sttp.tapir.server.ServerEndpoint -import sttp.tapir.server.interceptor._ - -/** Specifies what should be done if decoding the request has failed for all endpoints, and multiple endpoints have been - * interpreted (doesn't do anything when interpreting a single endpoint). - * - * By default, if there's a method decode failure, this means that the path must have matched (as it's decoded first); - * then, returning a 405 (method not allowed). - * - * In other cases, not returning a response, assuming that the interpreter will return a "no match" to the server - * implementation. - */ -class RejectInterceptor[F[_]](handler: RejectHandler[F]) extends RequestInterceptor[F] { - override def apply[R, B]( - responder: Responder[F, B], - requestHandler: EndpointInterceptor[F] => RequestHandler[F, R, B] - ): RequestHandler[F, R, B] = { - val next = requestHandler(EndpointInterceptor.noop) - new RequestHandler[F, R, B] { - override def apply(request: ServerRequest, endpoints: List[ServerEndpoint[R, F]])(implicit - monad: MonadError[F] - ): F[RequestResult[B]] = - next(request, endpoints).flatMap { - case r: RequestResult.Response[B] => (r: RequestResult[B]).unit - case f: RequestResult.Failure => - handler(f).flatMap { - case Some(value) => responder(request, value).map(RequestResult.Response(_)) - case None => (f: RequestResult[B]).unit - } - } - } - } -} - -object RejectInterceptor { - - /** When interpreting a single endpoint, disabling the reject interceptor, as returning a method mismatch only makes - * sense when there are more endpoints - */ - def disableWhenSingleEndpoint[F[_]]( - interceptors: List[Interceptor[F]], - ses: List[ServerEndpoint[_, F]] - ): List[Interceptor[F]] = - if (ses.length > 1) interceptors else interceptors.filterNot(_.isInstanceOf[RejectInterceptor[F]]) -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interpreter/BodyListener.scala b/tapir/core/src/main/scala/sttp/tapir/server/interpreter/BodyListener.scala deleted file mode 100644 index f4addb2f..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interpreter/BodyListener.scala +++ /dev/null @@ -1,13 +0,0 @@ -package sttp.tapir.server.interpreter - -import scala.util.Try - -trait BodyListener[F[_], B] { - def onComplete(body: B)(cb: Try[Unit] => F[Unit]): F[B] -} - -object BodyListener { - implicit class BodyListenerOps[B](body: B) { - def onComplete[F[_]](cb: Try[Unit] => F[Unit])(implicit l: BodyListener[F, B]): F[B] = l.onComplete(body)(cb) - } -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interpreter/DecodeBasicInputs.scala b/tapir/core/src/main/scala/sttp/tapir/server/interpreter/DecodeBasicInputs.scala deleted file mode 100644 index 2ca506d7..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interpreter/DecodeBasicInputs.scala +++ /dev/null @@ -1,347 +0,0 @@ -package sttp.tapir.server.interpreter - -import sttp.model.headers.Cookie -import sttp.model.{ContentTypeRange, HeaderNames, MediaType, Method, QueryParams} -import sttp.tapir.internal._ -import sttp.tapir.model.ServerRequest -import sttp.tapir.{DecodeResult, EndpointIO, EndpointInput, StreamBodyIO, oneOfBody} - -import scala.annotation.tailrec - -sealed trait DecodeBasicInputsResult -object DecodeBasicInputsResult { - - /** @param basicInputsValues Values of basic inputs, in order as they are defined in the endpoint. */ - case class Values( - basicInputsValues: Vector[Any], - bodyInputWithIndex: Option[(Either[EndpointIO.OneOfBody[_, _], EndpointIO.StreamBodyWrapper[_, _]], Int)] - ) extends DecodeBasicInputsResult { - private def verifyNoBody(input: EndpointInput[_]): Unit = if (bodyInputWithIndex.isDefined) { - throw new IllegalStateException(s"Double body definition: $input") - } - def addBodyInput[O](input: EndpointIO.Body[_, O], bodyIndex: Int): Values = { - verifyNoBody(input) - copy(bodyInputWithIndex = Some((Left(oneOfBody(ContentTypeRange.AnyRange -> input)), bodyIndex))) - } - def addOneOfBodyInput(input: EndpointIO.OneOfBody[_, _], bodyIndex: Int): Values = { - verifyNoBody(input) - copy(bodyInputWithIndex = Some((Left(input), bodyIndex))) - } - def addStreamingBodyInput(input: EndpointIO.StreamBodyWrapper[_, _], bodyIndex: Int): Values = { - verifyNoBody(input) - copy(bodyInputWithIndex = Some((Right(input), bodyIndex))) - } - - /** Sets the value of the body input, once it is known, if a body input is defined. */ - def setBodyInputValue(v: Any): Values = bodyInputWithIndex match { - case Some((_, i)) => copy(basicInputsValues = basicInputsValues.updated(i, v)) - case None => this - } - - def setBasicInputValue(v: Any, i: Int): Values = copy(basicInputsValues = basicInputsValues.updated(i, v)) - } - case class Failure(input: EndpointInput.Basic[_], failure: DecodeResult.Failure) extends DecodeBasicInputsResult - - def higherPriorityFailure(l: DecodeBasicInputsResult, r: DecodeBasicInputsResult): Option[Failure] = (l, r) match { - case (f1: Failure, _: Values) => Some(f1) - case (_: Values, f2: Failure) => Some(f2) - case (f1: Failure, f2: Failure) => Some(if (basicInputSortIndex(f2.input) < basicInputSortIndex(f1.input)) f2 else f1) - case _ => None - } -} - -/** @param previousLastPathInput - * The last path input from decoding a previous segment of inputs (security inputs), if any. - */ -case class DecodeInputsContext(request: ServerRequest, pathSegments: List[String], previousLastPathInput: Option[EndpointInput.Basic[_]]) { - def method: Method = request.method - def nextPathSegment: (Option[String], DecodeInputsContext) = - pathSegments match { - case Nil => (None, this) - case h :: t => (Some(h), DecodeInputsContext(request, t, previousLastPathInput)) - } - def header(name: String): List[String] = request.headers(name).toList - def headers: Seq[(String, String)] = request.headers.map(h => (h.name, h.value)) - def queryParameter(name: String): Seq[String] = queryParameters.getMulti(name).getOrElse(Nil) - val queryParameters: QueryParams = request.queryParameters -} -object DecodeInputsContext { - def apply(request: ServerRequest): DecodeInputsContext = DecodeInputsContext(request, request.pathSegments, None) -} - -object DecodeBasicInputs { - case class IndexedBasicInput(input: EndpointInput.Basic[_], index: Int) - - /** Decodes values of all basic inputs defined by the given `input`, and returns a map from the input to the input's value. - * - * An exception is the body input, which is not decoded. This is because typically bodies can be only read once. That's why, all non-body - * inputs are used to decide if a request matches the endpoint, or not. If a body input is present, it is also returned as part of the - * result. - * - * In case any of the decoding fails, the failure is returned together with the failing input. - * - * @param ctx - * The context, in which to decode the input. Contains the original request and progress in decoding the path - * @param matchWholePath - * Should the whole path be matched - that is, if the input doesn't exhaust the path, should a failure be reported - */ - def apply( - input: EndpointInput[_], - ctx: DecodeInputsContext, - matchWholePath: Boolean = true - ): (DecodeBasicInputsResult, DecodeInputsContext) = { - // The first decoding failure is returned. - // We decode in the following order: method, path, query, headers (incl. cookies), request, status, body - // An exact-path check is done after path & method matching - - val basicInputs = input.asVectorOfBasicInputs().zipWithIndex.map { case (el, i) => IndexedBasicInput(el, i) } - - val methodInputs = basicInputs.filter(t => isRequestMethod(t.input)) - val pathInputs = basicInputs.filter(t => isPath(t.input)) - val otherInputs = basicInputs.filterNot(t => isRequestMethod(t.input) || isPath(t.input)).sortBy(t => basicInputSortIndex(t.input)) - - // we're using null as a placeholder for the future values. All except the body (which is determined by - // interpreter-specific code), should be filled by the end of this method. - compose( - matchOthers(methodInputs, _, _), - matchPath(pathInputs, _, _, matchWholePath), - matchOthers(otherInputs, _, _) - )(DecodeBasicInputsResult.Values(Vector.fill(basicInputs.size)(null), None), ctx) - } - - /** We're decoding paths differently than other inputs. We first map all path segments to their decoding results (not checking if this is - * a successful or failed decoding at this stage). This is collected as the `decodedPathInputs` value. - * - * Once this is done, we check if there are remaining path segments. If yes - the decoding fails with a `Mismatch`. - * - * Hence, a failure due to a mismatch in the number of segments takes **priority** over any potential failures in decoding the segments. - */ - private def matchPath( - pathInputs: Vector[IndexedBasicInput], - decodeValues: DecodeBasicInputsResult.Values, - ctx: DecodeInputsContext, - matchWholePath: Boolean - ): (DecodeBasicInputsResult, DecodeInputsContext) = { - def matchPathInnerUsingLast(last: EndpointInput.Basic[_]) = matchPathInner( - pathInputs = pathInputs, - ctx = ctx.copy(previousLastPathInput = Some(last)), - decodeValues = decodeValues, - decodedPathInputs = Vector.empty, - lastPathInput = last, - matchWholePath = matchWholePath - ) - - (pathInputs.initAndLast, ctx.previousLastPathInput) match { - case (None, None) => - // Match everything if no path input is specified - (decodeValues, ctx) - case (Some((_, last)), _) => - // There are more path inputs, match the path exactly and report errors against the last path input - matchPathInnerUsingLast(last.input) - case (_, Some(last)) => - // There are no more path inputs, but some path has already been matched; match the path exactly and report - // possible errors against the last known path input - matchPathInnerUsingLast(last) - } - } - - @tailrec - private def matchPathInner( - pathInputs: Vector[IndexedBasicInput], - ctx: DecodeInputsContext, - decodeValues: DecodeBasicInputsResult.Values, - decodedPathInputs: Vector[(IndexedBasicInput, DecodeResult[_])], - lastPathInput: EndpointInput.Basic[_], - matchWholePath: Boolean - ): (DecodeBasicInputsResult, DecodeInputsContext) = { - pathInputs.headAndTail match { - case Some((idxInput @ IndexedBasicInput(in, _), restInputs)) => - in match { - case EndpointInput.FixedPath(expectedSegment, codec, _) => - val (nextSegment, newCtx) = ctx.nextPathSegment - nextSegment match { - case Some(seg) => - if (seg == expectedSegment) { - val newDecodedPathInputs = decodedPathInputs :+ ((idxInput, codec.decode(seg))) - matchPathInner(restInputs, newCtx, decodeValues, newDecodedPathInputs, idxInput.input, matchWholePath) - } else { - val failure = DecodeBasicInputsResult.Failure(in, DecodeResult.Mismatch(expectedSegment, seg)) - (failure, newCtx) - } - case None => - if (expectedSegment.isEmpty) { - // FixedPath("") matches an empty path - val newDecodedPathInputs = decodedPathInputs :+ ((idxInput, codec.decode(""))) - matchPathInner(restInputs, newCtx, decodeValues, newDecodedPathInputs, idxInput.input, matchWholePath) - } else { - // shape path mismatch - input path too short - val failure = DecodeBasicInputsResult.Failure(in, DecodeResult.Missing) - (failure, newCtx) - } - } - case i: EndpointInput.PathCapture[_] => - val (nextSegment, newCtx) = ctx.nextPathSegment - nextSegment match { - case Some(seg) => - val newDecodedPathInputs = decodedPathInputs :+ ((idxInput, i.codec.decode(seg))) - matchPathInner(restInputs, newCtx, decodeValues, newDecodedPathInputs, idxInput.input, matchWholePath) - case None => - val failure = DecodeBasicInputsResult.Failure(in, DecodeResult.Missing) - (failure, newCtx) - } - case i: EndpointInput.PathsCapture[_] => - val (paths, newCtx) = collectRemainingPath(Vector.empty, ctx) - val newDecodedPathInputs = decodedPathInputs :+ ((idxInput, i.codec.decode(paths.toList))) - matchPathInner(restInputs, newCtx, decodeValues, newDecodedPathInputs, idxInput.input, matchWholePath) - case _ => - throw new IllegalStateException(s"Unexpected EndpointInput ${in.show} encountered. This is most likely a bug in the library") - } - case None => - val (extraSegmentOpt, newCtx) = ctx.nextPathSegment - extraSegmentOpt match { - case Some(_) if matchWholePath => - // shape path mismatch - input path too long; there are more segments in the request path than expected by - // that input. Reporting a failure on the last path input. - val failure = - DecodeBasicInputsResult.Failure(lastPathInput, DecodeResult.Multiple(collectRemainingPath(Vector.empty, ctx)._1)) - (failure, newCtx) - case _ => - (foldDecodedPathInputs(decodedPathInputs, decodeValues), ctx) - } - } - } - - @tailrec - private def foldDecodedPathInputs( - decodedPathInputs: Vector[(IndexedBasicInput, DecodeResult[_])], - acc: DecodeBasicInputsResult.Values - ): DecodeBasicInputsResult = { - decodedPathInputs.headAndTail match { - case None => acc - case Some((t, ts)) => - t match { - case (indexedInput, failure: DecodeResult.Failure) => DecodeBasicInputsResult.Failure(indexedInput.input, failure) - case (indexedInput, DecodeResult.Value(v)) => foldDecodedPathInputs(ts, acc.setBasicInputValue(v, indexedInput.index)) - } - } - } - - @tailrec - private def collectRemainingPath(acc: Vector[String], c: DecodeInputsContext): (Vector[String], DecodeInputsContext) = - c.nextPathSegment match { - case (Some(s), c2) => collectRemainingPath(acc :+ s, c2) - case (None, c2) => (acc, c2) - } - - @tailrec - private def matchOthers( - inputs: Vector[IndexedBasicInput], - values: DecodeBasicInputsResult.Values, - ctx: DecodeInputsContext - ): (DecodeBasicInputsResult, DecodeInputsContext) = { - inputs.headAndTail match { - case None => (values, ctx) - case Some((IndexedBasicInput(input @ EndpointIO.Body(_, _, _), index), inputsTail)) => - matchOthers(inputsTail, values.addBodyInput(input, index), ctx) - case Some((IndexedBasicInput(input @ EndpointIO.OneOfBody(_, _), index), inputsTail)) => - matchOthers(inputsTail, values.addOneOfBodyInput(input, index), ctx) - case Some((IndexedBasicInput(input @ EndpointIO.StreamBodyWrapper(StreamBodyIO(_, _, _, _, _)), index), inputsTail)) => - matchOthers(inputsTail, values.addStreamingBodyInput(input, index), ctx) - case Some((indexedInput, inputsTail)) => - val (result, ctx2) = matchOther(indexedInput.input, ctx) - result match { - case DecodeResult.Value(v) => matchOthers(inputsTail, values.setBasicInputValue(v, indexedInput.index), ctx2) - case failure: DecodeResult.Failure => (DecodeBasicInputsResult.Failure(indexedInput.input, failure), ctx2) - } - } - } - - private def matchOther(input: EndpointInput.Basic[_], ctx: DecodeInputsContext): (DecodeResult[_], DecodeInputsContext) = { - input match { - case EndpointInput.FixedMethod(m, codec, _) => - if (m == ctx.method) (codec.decode(()), ctx) - else (DecodeResult.Mismatch(m.method, ctx.method.method), ctx) - - case EndpointIO.FixedHeader(h @ sttp.model.Header(n, v), codec, _) => - if (ctx.header(n) == Nil) (DecodeResult.Missing, ctx) - else if (List(v) == ctx.header(n)) (codec.decode(()), ctx) - else if (h.is(HeaderNames.ContentType)) { - // do not compare Content-Type 'boundary' directive - val inMedia = MediaType.parse(ctx.header(n).head).map(_.copy(otherParameters = Map.empty)) - val reqMedia = MediaType.parse(v).map(_.copy(otherParameters = Map.empty)) - if (inMedia == reqMedia) (codec.decode(()), ctx) - else (DecodeResult.Mismatch(reqMedia.toString, inMedia.toString), ctx) - } else (DecodeResult.Mismatch(List(v).mkString, ctx.header(n).mkString), ctx) - - case EndpointInput.Query(name, None, codec, _) => - (codec.decode(ctx.queryParameter(name).toList), ctx) - - case EndpointInput.Query(name, Some(flagValue), codec, _) => - ctx.queryParameters.getMulti(name) match { - case Some(Seq()) | Some(Seq("")) => (DecodeResult.Value(flagValue), ctx) - case values => (codec.decode(values.getOrElse(Nil).toList), ctx) - } - - case EndpointInput.QueryParams(codec, _) => - (codec.decode(ctx.queryParameters), ctx) - - case EndpointInput.Cookie(name, codec, _) => - val allCookies = DecodeResult - .sequence( - ctx.headers - .filter(_._1.equalsIgnoreCase(HeaderNames.Cookie)) - .map(p => - Cookie.parse(p._2) match { - case Left(e) => DecodeResult.Error(p._2, new RuntimeException(e)) - case Right(c) => DecodeResult.Value(c) - } - ) - ) - .map(_.flatten) - val decodedCookieValue = allCookies.map(_.find(_.name == name).map(_.value)).flatMap(codec.decode) - (decodedCookieValue, ctx) - - case EndpointIO.Header(name, codec, _) => - (codec.decode(ctx.header(name)), ctx) - - case EndpointIO.Headers(codec, _) => - (codec.decode(ctx.headers.map((sttp.model.Header.apply _).tupled).toList), ctx) - - case EndpointInput.ExtractFromRequest(codec, _) => - (codec.decode(ctx.request), ctx) - - case EndpointIO.Empty(codec, _) => - (codec.decode(()), ctx) - - case input => - throw new IllegalStateException( - s"Unexpected EndpointInput ${input.show} encountered. This is most likely a bug in the library" - ) - } - } - - private val isRequestMethod: EndpointInput.Basic[_] => Boolean = { - case _: EndpointInput.FixedMethod[_] => true - case _ => false - } - - private val isPath: EndpointInput.Basic[_] => Boolean = { - case _: EndpointInput.FixedPath[_] => true - case _: EndpointInput.PathCapture[_] => true - case _: EndpointInput.PathsCapture[_] => true - case _ => false - } - - private type DecodeInputResultTransform = - (DecodeBasicInputsResult.Values, DecodeInputsContext) => (DecodeBasicInputsResult, DecodeInputsContext) - private def compose(fs: DecodeInputResultTransform*): DecodeInputResultTransform = { (values, ctx) => - fs match { - case f +: tail => - f(values, ctx) match { - case (values2: DecodeBasicInputsResult.Values, ctx2) => compose(tail: _*)(values2, ctx2) - case r => r - } - case _ => (values, ctx) - } - } -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interpreter/EncodeOutputs.scala b/tapir/core/src/main/scala/sttp/tapir/server/interpreter/EncodeOutputs.scala deleted file mode 100644 index c5020603..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interpreter/EncodeOutputs.scala +++ /dev/null @@ -1,154 +0,0 @@ -package sttp.tapir.server.interpreter - -import sttp.model._ -import sttp.tapir.EndpointIO.OneOfBodyVariant -import sttp.tapir.EndpointOutput.OneOfVariant -import sttp.tapir.internal.{Params, ParamsAsAny, SplitParams, _} -import sttp.tapir.{Codec, CodecFormat, EndpointIO, EndpointOutput, Mapping, StreamBodyIO, WebSocketBodyOutput} - -import java.nio.charset.Charset -import scala.collection.immutable.Seq - -class EncodeOutputs[B, S](rawToResponseBody: ToResponseBody[B, S], acceptsContentTypes: Seq[ContentTypeRange]) { - def apply(output: EndpointOutput[_], value: Params, ov: OutputValues[B]): OutputValues[B] = { - output match { - case s: EndpointIO.Single[_] => applySingle(s, value, ov) - case s: EndpointOutput.Single[_] => applySingle(s, value, ov) - case EndpointIO.Pair(left, right, _, split) => applyPair(left, right, split, value, ov) - case EndpointOutput.Pair(left, right, _, split) => applyPair(left, right, split, value, ov) - case EndpointOutput.Void() => throw new IllegalArgumentException("Cannot encode a void output!") - } - } - - private def applyPair( - left: EndpointOutput[_], - right: EndpointOutput[_], - split: SplitParams, - params: Params, - ov: OutputValues[B] - ): OutputValues[B] = { - val (leftParams, rightParams) = split(params) - apply(right, rightParams, apply(left, leftParams, ov)) - } - - private def applySingle(output: EndpointOutput.Single[_], value: Params, ov: OutputValues[B]): OutputValues[B] = { - def encodedC[T](codec: Codec[_, _, _ <: CodecFormat]): T = codec.asInstanceOf[Codec[T, Any, CodecFormat]].encode(value.asAny) - def encodedM[T](mapping: Mapping[_, _]): T = mapping.asInstanceOf[Mapping[T, Any]].encode(value.asAny) - output match { - case EndpointIO.Empty(_, _) => ov - case EndpointOutput.FixedStatusCode(sc, _, _) => ov.withStatusCode(sc) - case EndpointIO.FixedHeader(header, _, _) => ov.withHeader(header.name, header.value) - case EndpointIO.Body(rawBodyType, codec, _) => - val maybeCharset = if (codec.format.mediaType.isText) charset(rawBodyType) else None - ov.withBody(headers => rawToResponseBody.fromRawValue(encodedC(codec), headers, codec.format, rawBodyType)) - .withDefaultContentType(codec.format, maybeCharset) - case EndpointIO.OneOfBody(variants, mapping) => applySingle(chooseOneOfVariant(variants), ParamsAsAny(encodedM[Any](mapping)), ov) - case EndpointIO.StreamBodyWrapper(StreamBodyIO(_, codec, _, charset, _)) => - ov.withBody(headers => rawToResponseBody.fromStreamValue(encodedC(codec), headers, codec.format, charset)) - .withDefaultContentType(codec.format, charset) - .withHeaderTransformation(hs => - if (hs.exists(_.is(HeaderNames.ContentLength))) hs else hs :+ Header(HeaderNames.TransferEncoding, "chunked") - ) - case EndpointIO.Header(name, codec, _) => - encodedC[List[String]](codec).foldLeft(ov) { case (ovv, headerValue) => ovv.withHeader(name, headerValue) } - case EndpointIO.Headers(codec, _) => - encodedC[List[sttp.model.Header]](codec).foldLeft(ov)((ov2, h) => ov2.withHeader(h.name, h.value)) - case EndpointIO.MappedPair(wrapped, mapping) => apply(wrapped, ParamsAsAny(encodedM[Any](mapping)), ov) - case EndpointOutput.StatusCode(_, codec, _) => ov.withStatusCode(encodedC[StatusCode](codec)) - case EndpointOutput.WebSocketBodyWrapper(o) => - ov.withBody(_ => - rawToResponseBody.fromWebSocketPipe( - encodedC[rawToResponseBody.streams.Pipe[Any, Any]](o.codec), - o.asInstanceOf[WebSocketBodyOutput[rawToResponseBody.streams.Pipe[Any, Any], Any, Any, Any, S]] - ) - ) - case o @ EndpointOutput.OneOf(mappings, mapping) => - val enc = encodedM[Any](mapping) - val applicableMappings = mappings.filter(_.appliesTo(enc)) - if (applicableMappings.isEmpty) { - throw new IllegalArgumentException( - s"None of the mappings defined in the one-of output: ${o.show}, is applicable to the value: $enc. " + - s"Verify that the type parameters to oneOf are correct, and that the oneOfVariants are exhaustive " + - s"(that is, that they cover all possible cases)." - ) - } - - val chosenVariant = chooseOneOfVariant(applicableMappings) - apply(chosenVariant.output, ParamsAsAny(enc), ov) - - case EndpointOutput.MappedPair(wrapped, mapping) => apply(wrapped, ParamsAsAny(encodedM[Any](mapping)), ov) - } - } - - private def chooseOneOfVariant(variants: List[OneOfBodyVariant[_]]): EndpointIO.Atom[_] = { - val mediaTypeToBody = variants.map(v => v.mediaTypeWithCharset -> v) - chooseBestVariant[OneOfBodyVariant[_]](mediaTypeToBody).getOrElse(variants.head).bodyAsAtom - } - - private def chooseOneOfVariant(variants: Seq[OneOfVariant[_]]): OneOfVariant[_] = { - // #1164: there might be multiple applicable mappings, for the same content type - e.g. when there's a default - // mapping. We need to take the first defined into account. - val bodyVariants: Seq[(MediaType, OneOfVariant[_])] = variants - .flatMap { om => - val mediaTypeFromBody = om.output.traverseOutputs { - case b: EndpointIO.Body[_, _] => Vector[(MediaType, OneOfVariant[_])](b.mediaTypeWithCharset -> om) - case b: EndpointIO.StreamBodyWrapper[_, _] => Vector[(MediaType, OneOfVariant[_])](b.mediaTypeWithCharset -> om) - } - - // #2200: some variants might have no body, which means that they match any of the `acceptsContentTypes`; - // in this case, creating a "fake" media type which will match the first content range - if (mediaTypeFromBody.isEmpty) { - val fakeMediaType = acceptsContentTypes.headOption - .map(r => MediaType(r.mainType, r.subType)) - .getOrElse(MediaType.ApplicationOctetStream) - Vector(fakeMediaType -> om) - } else mediaTypeFromBody - } - - chooseBestVariant(bodyVariants).getOrElse(variants.head) - } - - private def chooseBestVariant[T](variants: Seq[(MediaType, T)]): Option[T] = { - if (variants.nonEmpty) { - val mediaTypes = variants.map(_._1) - MediaType - .bestMatch(mediaTypes, acceptsContentTypes) - .flatMap(mt => variants.find(_._1 == mt).map(_._2)) - } else None - } -} - -case class OutputValues[B]( - body: Option[HasHeaders => B], - baseHeaders: Vector[Header], - headerTransformations: Vector[Vector[Header] => Vector[Header]], - statusCode: Option[StatusCode] -) { - def withBody(b: HasHeaders => B): OutputValues[B] = { - if (body.isDefined) { - throw new IllegalArgumentException("Body is already defined") - } - - copy(body = Some(b)) - } - - def withHeaderTransformation(t: Vector[Header] => Vector[Header]): OutputValues[B] = - copy(headerTransformations = headerTransformations :+ t) - def withDefaultContentType(format: CodecFormat, charset: Option[Charset]): OutputValues[B] = { - withHeaderTransformation { hs => - if (hs.exists(_.is(HeaderNames.ContentType))) hs - else hs :+ Header(HeaderNames.ContentType, charset.fold(format.mediaType)(format.mediaType.charset(_)).toString()) - } - } - - def withHeader(n: String, v: String): OutputValues[B] = copy(baseHeaders = baseHeaders :+ Header(n, v)) - - def withStatusCode(sc: StatusCode): OutputValues[B] = copy(statusCode = Some(sc)) - - def headers: Seq[Header] = { - headerTransformations.foldLeft(baseHeaders) { case (hs, t) => t(hs) } - } -} -object OutputValues { - def empty[B]: OutputValues[B] = OutputValues[B](None, Vector.empty, Vector.empty, None) -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interpreter/FilterServerEndpoints.scala b/tapir/core/src/main/scala/sttp/tapir/server/interpreter/FilterServerEndpoints.scala deleted file mode 100644 index 465ccd97..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interpreter/FilterServerEndpoints.scala +++ /dev/null @@ -1,114 +0,0 @@ -package sttp.tapir.server.interpreter - -import sttp.tapir.{AnyEndpoint, EndpointInput} -import sttp.tapir.internal.RichEndpointInput -import sttp.tapir.model.ServerRequest -import sttp.tapir.server.ServerEndpoint - -class FilterServerEndpoints[R, F[_]](rootLayer: PathLayer[R, F]) extends (ServerRequest => List[ServerEndpoint[R, F]]) { - - /** Given a request, returns the list of server endpoints which might potentially decode successfully, taking into account the path of the - * request. - */ - def apply(request: ServerRequest): List[ServerEndpoint[R, F]] = - request.pathSegments.foldLeft(rootLayer) { case (layer, segment) => layer.next(segment) }.endpoints -} - -object FilterServerEndpoints { - private sealed trait PathSegment - private case class Exact(s: String) extends PathSegment - private case object AnySingle extends PathSegment // ? - private case object AnyMulti extends PathSegment // * - - private def segmentsForEndpoint(e: AnyEndpoint): List[PathSegment] = { - val segments = e.securityInput - .and(e.input) - .asVectorOfBasicInputs() - .collect { - case EndpointInput.FixedPath(s, _, _) => Exact(s) - case _: EndpointInput.PathCapture[_] => AnySingle - case _: EndpointInput.PathsCapture[_] => AnyMulti - } - .toList - - // no path segments mean that the endpoint matches any path - if (segments.isEmpty) List(AnyMulti) else segments - } - - private def createLayer[R, F[_]](segmentsToEndpoints: List[(List[PathSegment], ServerEndpoint[R, F])]): PathLayer[R, F] = { - // first computing the distinct segments with which endpoints at this layer start - val distinctSegments: Set[PathSegment] = segmentsToEndpoints.flatMap(_._1.headOption).toSet - - val exactSegmentToNextLayer: Map[String, PathLayer[R, F]] = distinctSegments - // for each exact segment, creating the next layer - .collect { case e: Exact => e } - .map { exactSegment => - // computing the endpoints for the next layer: these are the endpoints which start with the given exact segment, - // a single or multi wildcard; peeling off this initial segment, unless it is a multi wildcard (as it can capture - // any number of path segments) - val peeledSegmentsToEndpoints = segmentsToEndpoints.flatMap { case (segments, se) => - segments match { - case head :: tail if head == exactSegment || head == AnySingle => List(tail -> se) - case head :: _ if head == AnyMulti => List(segments -> se) - case _ => Nil - } - } - - exactSegment.s -> createLayer(peeledSegmentsToEndpoints) - } - .toMap - - val wildcardSegmentNextLayer = { - // to avoid an infinite loop, if we are only left with multi wildcard path segments (capturing any number of path segments), - // returning a "terminal" layer with these endpoints, which will match any path - if (segmentsToEndpoints.forall { case (segments, _) => segments.headOption.contains(AnyMulti) }) { - new PathLayer[R, F] { - override val endpoints: List[ServerEndpoint[R, F]] = segmentsToEndpoints.map(_._2) - override def next(pathSegment: String): PathLayer[R, F] = this - } - } else { - // removing the head single wildcard path segments, and creating the next layer - val peeledSegmentsToEndpoints = segmentsToEndpoints.flatMap { case (segments, se) => - segments match { - case head :: tail if head == AnySingle => List(tail -> se) - case head :: _ if head == AnyMulti => List(segments -> se) - case _ => Nil - } - } - - createLayer(peeledSegmentsToEndpoints) - } - } - - new PathLayer[R, F] { - // endpoints at this layer are all for which there are no more path segments, or if the only path segment is *; - // an empty exact segment is a special-case used to denote root paths - override val endpoints: List[ServerEndpoint[R, F]] = segmentsToEndpoints - .filter { case (segments, _) => - segments.isEmpty || segments.headOption.contains(AnyMulti) || segments.headOption.contains(Exact("")) - } - .map(_._2) - - override def next(pathSegment: String): PathLayer[R, F] = { - // if there's at least one endpoint with an exact segment as the given one, looking up its layer (which also includes wildcard segments) - // otherwise, returning the layer which handles endpoints with wildcard segments - exactSegmentToNextLayer.getOrElse(pathSegment, wildcardSegmentNextLayer) - } - } - } - - def apply[R, F[_]](serverEndpoints: List[ServerEndpoint[R, F]]): FilterServerEndpoints[R, F] = { - val segmentsToEndpoints: List[(List[PathSegment], ServerEndpoint[R, F])] = - serverEndpoints.map(se => segmentsForEndpoint(se.endpoint) -> se) - - new FilterServerEndpoints[R, F](createLayer(segmentsToEndpoints)) - } -} - -private trait PathLayer[R, F[_]] { - - /** Endpoints at this layer, if the path is finished */ - def endpoints: List[ServerEndpoint[R, F]] - - def next(pathSegment: String): PathLayer[R, F] -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interpreter/InputValue.scala b/tapir/core/src/main/scala/sttp/tapir/server/interpreter/InputValue.scala deleted file mode 100644 index 992ec73b..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interpreter/InputValue.scala +++ /dev/null @@ -1,67 +0,0 @@ -package sttp.tapir.server.interpreter - -import sttp.tapir.internal.{CombineParams, Params, ParamsAsAny, RichVector} -import sttp.tapir.{DecodeResult, EndpointIO, EndpointInput, Mapping} - -sealed trait InputValueResult -object InputValueResult { - case class Value(params: Params, remainingBasicValues: Vector[Any]) extends InputValueResult - case class Failure(input: EndpointInput[_], failure: DecodeResult.Failure) extends InputValueResult -} - -object InputValue { - - /** Returns the value of the input, tupled and mapped as described by the data structure. Values of basic inputs are taken as consecutive - * values from `values.basicInputsValues`. Hence, these should match (in order). - */ - def apply(input: EndpointInput[_], values: DecodeBasicInputsResult.Values): InputValueResult = - apply(input, values.basicInputsValues) - - private def apply(input: EndpointInput[_], remainingBasicValues: Vector[Any]): InputValueResult = { - input match { - case EndpointInput.Pair(left, right, combine, _) => handlePair(left, right, combine, remainingBasicValues) - case EndpointIO.Pair(left, right, combine, _) => handlePair(left, right, combine, remainingBasicValues) - case EndpointInput.MappedPair(wrapped, codec) => handleMappedPair(wrapped, codec, remainingBasicValues) - case EndpointIO.MappedPair(wrapped, codec) => handleMappedPair(wrapped, codec, remainingBasicValues) - case auth: EndpointInput.Auth[_, _] => apply(auth.input, remainingBasicValues) - case _: EndpointInput.Basic[_] => - remainingBasicValues.headAndTail match { - case Some((v, valuesTail)) => InputValueResult.Value(ParamsAsAny(v), valuesTail) - case None => - throw new IllegalStateException(s"Mismatch between basic input values: $remainingBasicValues, and basic inputs in: $input") - } - } - } - - private def handlePair( - left: EndpointInput[_], - right: EndpointInput[_], - combine: CombineParams, - remainingBasicValues: Vector[Any] - ): InputValueResult = { - apply(left, remainingBasicValues) match { - case InputValueResult.Value(leftParams, remainingBasicValues2) => - apply(right, remainingBasicValues2) match { - case InputValueResult.Value(rightParams, remainingBasicValues3) => - InputValueResult.Value(combine(leftParams, rightParams), remainingBasicValues3) - case f2: InputValueResult.Failure => f2 - } - case f: InputValueResult.Failure => f - } - } - - private def handleMappedPair[II, T]( - wrapped: EndpointInput[II], - codec: Mapping[II, T], - remainingBasicValues: Vector[Any] - ): InputValueResult = { - apply(wrapped, remainingBasicValues) match { - case InputValueResult.Value(pairValue, remainingBasicValues2) => - codec.decode(pairValue.asAny.asInstanceOf[II]) match { - case DecodeResult.Value(v) => InputValueResult.Value(ParamsAsAny(v), remainingBasicValues2) - case f: DecodeResult.Failure => InputValueResult.Failure(wrapped, f) - } - case f: InputValueResult.Failure => f - } - } -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interpreter/RequestBody.scala b/tapir/core/src/main/scala/sttp/tapir/server/interpreter/RequestBody.scala deleted file mode 100644 index 2310c431..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interpreter/RequestBody.scala +++ /dev/null @@ -1,19 +0,0 @@ -package sttp.tapir.server.interpreter - -import sttp.capabilities.Streams -import sttp.model.Part -import sttp.tapir.model.ServerRequest -import sttp.tapir.{FileRange, RawBodyType, RawPart} - -trait RequestBody[F[_], S] { - val streams: Streams[S] - def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): F[RawValue[R]] - def toStream(serverRequest: ServerRequest): streams.BinaryStream -} - -case class RawValue[R](value: R, createdFiles: Seq[FileRange] = Nil) - -object RawValue { - def fromParts(parts: Seq[RawPart]): RawValue[Seq[RawPart]] = - RawValue(parts, parts collect { case _ @Part(_, f: FileRange, _, _) => f }) -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala b/tapir/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala deleted file mode 100644 index ba9b0a05..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala +++ /dev/null @@ -1,274 +0,0 @@ -package sttp.tapir.server.interpreter - -import sttp.model.{Headers, StatusCode} -import sttp.monad.MonadError -import sttp.monad.syntax._ -import sttp.tapir.internal.{Params, ParamsAsAny, RichOneOfBody} -import sttp.tapir.model.ServerRequest -import sttp.tapir.server.{model, _} -import sttp.tapir.server.interceptor._ -import sttp.tapir.server.model.{ServerResponse, ValuedEndpointOutput} -import sttp.tapir.{DecodeResult, EndpointIO, EndpointInput, TapirFile} - -class ServerInterpreter[R, F[_], B, S]( - serverEndpoints: ServerRequest => List[ServerEndpoint[R, F]], - requestBody: RequestBody[F, S], - toResponseBody: ToResponseBody[B, S], - interceptors: List[Interceptor[F]], - deleteFile: TapirFile => F[Unit] -)(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]) { - def apply(request: ServerRequest): F[RequestResult[B]] = monad.suspend { - callInterceptors(interceptors, Nil, responder(defaultSuccessStatusCode)).apply(request, serverEndpoints(request)) - } - - /** Accumulates endpoint interceptors and calls `next` with the potentially transformed request. */ - private def callInterceptors( - is: List[Interceptor[F]], - eisAcc: List[EndpointInterceptor[F]], - responder: Responder[F, B] - ): RequestHandler[F, R, B] = { - is match { - case Nil => RequestHandler.from { (request, ses, _) => firstNotNone(request, ses, eisAcc.reverse, Nil) } - case (i: RequestInterceptor[F]) :: tail => - i( - responder, - { ei => RequestHandler.from { (request, ses, _) => callInterceptors(tail, ei :: eisAcc, responder).apply(request, ses) } } - ) - case (ei: EndpointInterceptor[F]) :: tail => callInterceptors(tail, ei :: eisAcc, responder) - } - } - - /** Try decoding subsequent server endpoints, until a non-None response is returned. */ - private def firstNotNone( - request: ServerRequest, - ses: List[ServerEndpoint[R, F]], - endpointInterceptors: List[EndpointInterceptor[F]], - accumulatedFailureContexts: List[DecodeFailureContext] - ): F[RequestResult[B]] = - ses match { - case Nil => (RequestResult.Failure(accumulatedFailureContexts.reverse): RequestResult[B]).unit - case se :: tail => - tryServerEndpoint[se.SECURITY_INPUT, se.PRINCIPAL, se.INPUT, se.ERROR_OUTPUT, se.OUTPUT]( - request, - se, - endpointInterceptors - ) - .flatMap { - case RequestResult.Failure(failureContexts) => - firstNotNone(request, tail, endpointInterceptors, failureContexts ++: accumulatedFailureContexts) - case r => r.unit - } - } - - private def tryServerEndpoint[A, U, I, E, O]( - request: ServerRequest, - se: ServerEndpoint.Full[A, U, I, E, O, R, F], - endpointInterceptors: List[EndpointInterceptor[F]] - ): F[RequestResult[B]] = { - val defaultSecurityFailureResponse = - ServerResponse[B](StatusCode.InternalServerError, Nil, None, None).unit - - def endpointHandler(securityFailureResponse: => F[ServerResponse[B]]): EndpointHandler[F, B] = - endpointInterceptors.foldRight(defaultEndpointHandler(securityFailureResponse)) { case (interceptor, handler) => - interceptor(responder(defaultSuccessStatusCode), handler) - } - - def resultOrValueFrom = new ResultOrValueFrom { - def onDecodeFailure(input: EndpointInput[_], failure: DecodeResult.Failure): F[RequestResult[B]] = { - val decodeFailureContext = interceptor.DecodeFailureContext(se.endpoint, input, failure, request) - endpointHandler(defaultSecurityFailureResponse) - .onDecodeFailure(decodeFailureContext) - .map { - case Some(response) => RequestResult.Response(response) - case None => RequestResult.Failure(List(decodeFailureContext)) - } - } - } - - // 1. decoding both security & regular basic inputs - note that this does *not* include decoding the body - val decodeBasicContext1 = DecodeInputsContext(request) - // the security input doesn't have to match the whole path, a prefix is fine - val (securityBasicInputs, decodeBasicContext2) = - DecodeBasicInputs(se.endpoint.securityInput, decodeBasicContext1, matchWholePath = false) - // the regular input is required to match the whole remaining path; otherwise a decode failure is reported - // to keep the progress in path matching, we are using the context returned by decoding the security input - val (regularBasicInputs, _) = DecodeBasicInputs(se.endpoint.input, decodeBasicContext2) - (for { - // 2. if the decoding failed, short-circuiting further processing with the decode failure that has a lower sort - // index (so that the correct one is passed to the decode failure handler) - _ <- resultOrValueFrom(DecodeBasicInputsResult.higherPriorityFailure(securityBasicInputs, regularBasicInputs)) - // 3. computing the security input value - securityValues <- resultOrValueFrom(decodeBody(request, securityBasicInputs)) - securityParams <- resultOrValueFrom(InputValue(se.endpoint.securityInput, securityValues)) - inputValues <- resultOrValueFrom(regularBasicInputs) - a = securityParams.asAny.asInstanceOf[A] - // 4. running the security logic - securityLogicResult <- ResultOrValue( - se.securityLogic(monad)(a).map(Right(_): Either[RequestResult[B], Either[E, U]]).handleError { case t: Throwable => - endpointHandler(monad.error(t)) - .onSecurityFailure(SecurityFailureContext(se, a, request)) - .map(r => Left(RequestResult.Response(r)): Either[RequestResult[B], Either[E, U]]) - } - ) - response <- securityLogicResult match { - case Left(e) => - resultOrValueFrom.value( - endpointHandler(responder(defaultErrorStatusCode)(request, model.ValuedEndpointOutput(se.endpoint.errorOutput, e))) - .onSecurityFailure(SecurityFailureContext(se, a, request)) - .map(r => RequestResult.Response(r): RequestResult[B]) - ) - - case Right(u) => - for { - // 5. decoding the body of regular inputs, computing the input value, and running the main logic - values <- resultOrValueFrom(decodeBody(request, inputValues)) - params <- resultOrValueFrom(InputValue(se.endpoint.input, values)) - response <- resultOrValueFrom.value( - endpointHandler(defaultSecurityFailureResponse) - .onDecodeSuccess(interceptor.DecodeSuccessContext(se, a, u, params.asAny.asInstanceOf[I], request)) - .map(r => RequestResult.Response(r): RequestResult[B]) - ) - } yield response - } - } yield response).fold - } - - private def decodeBody( - request: ServerRequest, - result: DecodeBasicInputsResult - ): F[DecodeBasicInputsResult] = - result match { - case values: DecodeBasicInputsResult.Values => - values.bodyInputWithIndex match { - case Some((Left(oneOfBodyInput), _)) => - oneOfBodyInput.chooseBodyToDecode(request.contentTypeParsed) match { - case Some(Left(body)) => decodeBody(request, values, body) - case Some(Right(body: EndpointIO.StreamBodyWrapper[Any, Any])) => decodeStreamingBody(request, values, body) - case None => unsupportedInputMediaTypeResponse(request, oneOfBodyInput) - } - case Some((Right(bodyInput: EndpointIO.StreamBodyWrapper[Any, Any]), _)) => decodeStreamingBody(request, values, bodyInput) - case None => (values: DecodeBasicInputsResult).unit - } - case failure: DecodeBasicInputsResult.Failure => (failure: DecodeBasicInputsResult).unit - } - - private def decodeStreamingBody( - request: ServerRequest, - values: DecodeBasicInputsResult.Values, - bodyInput: EndpointIO.StreamBodyWrapper[Any, Any] - ): F[DecodeBasicInputsResult] = - (bodyInput.codec.decode(requestBody.toStream(request)) match { - case DecodeResult.Value(bodyV) => values.setBodyInputValue(bodyV) - case failure: DecodeResult.Failure => DecodeBasicInputsResult.Failure(bodyInput, failure): DecodeBasicInputsResult - }).unit - - private def decodeBody[RAW, T]( - request: ServerRequest, - values: DecodeBasicInputsResult.Values, - bodyInput: EndpointIO.Body[RAW, T] - ): F[DecodeBasicInputsResult] = { - requestBody.toRaw(request, bodyInput.bodyType).flatMap { v => - bodyInput.codec.decode(v.value) match { - case DecodeResult.Value(bodyV) => (values.setBodyInputValue(bodyV): DecodeBasicInputsResult).unit - case failure: DecodeResult.Failure => - v.createdFiles - .foldLeft(monad.unit(()))((u, f) => u.flatMap(_ => deleteFile(f.file))) - .map(_ => DecodeBasicInputsResult.Failure(bodyInput, failure): DecodeBasicInputsResult) - } - } - } - - private def unsupportedInputMediaTypeResponse( - request: ServerRequest, - oneOfBodyInput: EndpointIO.OneOfBody[_, _] - ): F[DecodeBasicInputsResult] = - (DecodeBasicInputsResult.Failure( - oneOfBodyInput, - DecodeResult - .Mismatch(oneOfBodyInput.variants.map(_.range.toString()).mkString(", or: "), request.contentType.getOrElse("")) - ): DecodeBasicInputsResult).unit - - private def defaultEndpointHandler(securityFailureResponse: => F[ServerResponse[B]]): EndpointHandler[F, B] = - new EndpointHandler[F, B] { - override def onDecodeSuccess[A, U, I]( - ctx: DecodeSuccessContext[F, A, U, I] - )(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[ServerResponse[B]] = - ctx.serverEndpoint - .logic(implicitly)(ctx.principal)(ctx.input) - .flatMap { - case Right(result) => - responder(defaultSuccessStatusCode)(ctx.request, model.ValuedEndpointOutput(ctx.serverEndpoint.output, result)) - case Left(err) => - responder(defaultErrorStatusCode)(ctx.request, model.ValuedEndpointOutput(ctx.serverEndpoint.errorOutput, err)) - } - - override def onSecurityFailure[A]( - ctx: SecurityFailureContext[F, A] - )(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[ServerResponse[B]] = securityFailureResponse - - override def onDecodeFailure( - ctx: DecodeFailureContext - )(implicit monad: MonadError[F], bodyListener: BodyListener[F, B]): F[Option[ServerResponse[B]]] = - (None: Option[ServerResponse[B]]).unit(monad) - } - - private def responder(defaultStatusCode: StatusCode): Responder[F, B] = new Responder[F, B] { - override def apply[O](request: ServerRequest, output: ValuedEndpointOutput[O]): F[ServerResponse[B]] = { - val outputValues = - new EncodeOutputs(toResponseBody, request.acceptsContentTypes.getOrElse(Nil)) - .apply(output.output, ParamsAsAny(output.value), OutputValues.empty) - val statusCode = outputValues.statusCode.getOrElse(defaultStatusCode) - - val headers = outputValues.headers - outputValues.body match { - case Some(bodyFromHeaders) => ServerResponse(statusCode, headers, Some(bodyFromHeaders(Headers(headers))), Some(output)).unit - case None => ServerResponse(statusCode, headers, None: Option[B], Some(output)).unit - } - } - } - - private val defaultSuccessStatusCode: StatusCode = StatusCode.Ok - private val defaultErrorStatusCode: StatusCode = StatusCode.BadRequest - - private case class ResultOrValue[T](v: F[Either[RequestResult[B], T]]) { - def flatMap[U](f: T => ResultOrValue[U]): ResultOrValue[U] = { - ResultOrValue(v.flatMap { - case Left(r) => (Left(r): Either[RequestResult[B], U]).unit - case Right(v) => f(v).v - }) - } - def map[U](f: T => U): ResultOrValue[U] = { - ResultOrValue(v.map { - case Left(r) => Left(r): Either[RequestResult[B], U] - case Right(v) => Right(f(v)) - }) - } - def fold(implicit ev: T =:= RequestResult[B]): F[RequestResult[B]] = v.map { - case Left(r) => r - case Right(r) => r - } - } - - private abstract class ResultOrValueFrom { - def apply(v: F[DecodeBasicInputsResult]): ResultOrValue[DecodeBasicInputsResult.Values] = ResultOrValue(v.flatMap { - case v: DecodeBasicInputsResult.Values => (Right(v): Either[RequestResult[B], DecodeBasicInputsResult.Values]).unit - case DecodeBasicInputsResult.Failure(input, failure) => onDecodeFailure(input, failure).map(Left(_)) - }) - def apply(v: InputValueResult): ResultOrValue[Params] = v match { - case InputValueResult.Value(params, _) => ResultOrValue((Right(params): Either[RequestResult[B], Params]).unit) - case InputValueResult.Failure(input, failure) => ResultOrValue(onDecodeFailure(input, failure).map(Left(_))) - } - def apply(v: DecodeBasicInputsResult): ResultOrValue[DecodeBasicInputsResult.Values] = v match { - case v: DecodeBasicInputsResult.Values => - ResultOrValue((Right(v): Either[RequestResult[B], DecodeBasicInputsResult.Values]).unit) - case DecodeBasicInputsResult.Failure(input, failure) => ResultOrValue(onDecodeFailure(input, failure).map(Left(_))) - } - def apply(f: Option[DecodeBasicInputsResult.Failure]): ResultOrValue[Unit] = f match { - case None => ResultOrValue((Right(()): Either[RequestResult[B], Unit]).unit) - case Some(DecodeBasicInputsResult.Failure(input, failure)) => ResultOrValue(onDecodeFailure(input, failure).map(Left(_))) - } - def value[T](v: F[T]): ResultOrValue[T] = ResultOrValue(v.map(Right(_))) - - def onDecodeFailure(input: EndpointInput[_], failure: DecodeResult.Failure): F[RequestResult[B]] - } -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/interpreter/ToResponseBody.scala b/tapir/core/src/main/scala/sttp/tapir/server/interpreter/ToResponseBody.scala deleted file mode 100644 index 4ced4adb..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/interpreter/ToResponseBody.scala +++ /dev/null @@ -1,14 +0,0 @@ -package sttp.tapir.server.interpreter - -import sttp.capabilities.Streams -import sttp.model.HasHeaders -import sttp.tapir.{CodecFormat, RawBodyType, WebSocketBodyOutput} - -import java.nio.charset.Charset - -trait ToResponseBody[B, S] { - val streams: Streams[S] - def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): B // TODO: remove headers? - def fromStreamValue(v: streams.BinaryStream, headers: HasHeaders, format: CodecFormat, charset: Option[Charset]): B - def fromWebSocketPipe[REQ, RESP](pipe: streams.Pipe[REQ, RESP], o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, S]): B -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/metrics/Metric.scala b/tapir/core/src/main/scala/sttp/tapir/server/metrics/Metric.scala deleted file mode 100644 index 8c90e9fb..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/metrics/Metric.scala +++ /dev/null @@ -1,66 +0,0 @@ -package sttp.tapir.server.metrics - -import sttp.monad.MonadError -import sttp.tapir.AnyEndpoint -import sttp.tapir.model.ServerRequest -import sttp.tapir.server.model.ServerResponse - -case class Metric[F[_], M]( - metric: M, - /** Called when the request starts. */ - onRequest: (ServerRequest, M, MonadError[F]) => F[EndpointMetric[F]] -) - -case class EndpointMetric[F[_]]( - /** Called when an endpoint matches the request, before calling the server logic. */ - onEndpointRequest: Option[AnyEndpoint => F[Unit]] = None, - /** Called when the response headers are ready (not necessarily the whole response body). */ - onResponseHeaders: Option[(AnyEndpoint, ServerResponse[_]) => F[Unit]] = None, - /** Called when the response body is complete. */ - onResponseBody: Option[(AnyEndpoint, ServerResponse[_]) => F[Unit]] = None, - onException: Option[(AnyEndpoint, Throwable) => F[Unit]] = None -) { - def onEndpointRequest(f: AnyEndpoint => F[Unit]): EndpointMetric[F] = this.copy(onEndpointRequest = Some(f)) - def onResponseHeaders(f: (AnyEndpoint, ServerResponse[_]) => F[Unit]): EndpointMetric[F] = this.copy(onResponseHeaders = Some(f)) - def onResponseBody(f: (AnyEndpoint, ServerResponse[_]) => F[Unit]): EndpointMetric[F] = this.copy(onResponseBody = Some(f)) - def onException(f: (AnyEndpoint, Throwable) => F[Unit]): EndpointMetric[F] = this.copy(onException = Some(f)) -} - -case class ResponsePhaseLabel(name: String, headersValue: String, bodyValue: String) -case class MetricLabels( - forRequest: List[(String, (AnyEndpoint, ServerRequest) => String)], - forResponse: List[(String, Either[Throwable, ServerResponse[_]] => String)], - forResponsePhase: ResponsePhaseLabel = ResponsePhaseLabel("phase", "headers", "body") -) { - def namesForRequest: List[String] = forRequest.map { case (name, _) => name } - def namesForResponse: List[String] = forResponse.map { case (name, _) => name } - - def valuesForRequest(ep: AnyEndpoint, req: ServerRequest): List[String] = forRequest.map { case (_, f) => f(ep, req) } - def valuesForResponse(res: ServerResponse[_]): List[String] = forResponse.map { case (_, f) => f(Right(res)) } - def valuesForResponse(ex: Throwable): List[String] = forResponse.map { case (_, f) => f(Left(ex)) } -} - -object MetricLabels { - - /** Labels request by path and method, response by status code */ - lazy val Default: MetricLabels = MetricLabels( - forRequest = List( - "path" -> { case (ep, _) => ep.showPathTemplate(showQueryParam = None) }, - "method" -> { case (_, req) => req.method.method } - ), - forResponse = List( - "status" -> { - case Right(r) => - r.code match { - case c if c.isInformational => "1xx" - case c if c.isSuccess => "2xx" - case c if c.isRedirect => "3xx" - case c if c.isClientError => "4xx" - case c if c.isServerError => "5xx" - case _ => "" - } - case Left(_) => "5xx" - } - ) - ) -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/model/ServerResponse.scala b/tapir/core/src/main/scala/sttp/tapir/server/model/ServerResponse.scala deleted file mode 100644 index 15299954..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/model/ServerResponse.scala +++ /dev/null @@ -1,20 +0,0 @@ -package sttp.tapir.server.model - -import sttp.model.{Header, Headers, ResponseMetadata, StatusCode} - -import scala.collection.immutable.Seq - -/** @param source The output, from which this response has been created. */ -case class ServerResponse[+B](code: StatusCode, headers: Seq[Header], body: Option[B], source: Option[ValuedEndpointOutput[_]]) - extends ResponseMetadata { - override def statusText: String = "" - override def toString: String = s"ServerResponse($code,${Headers.toStringSafe(headers)})" - - def showShort: String = code.toString() - def showCodeAndHeaders: String = s"$code (${Headers.toStringSafe(headers)})" - def addHeaders(additionalHeaders: Seq[Header]): ServerResponse[B] = copy(headers = headers ++ additionalHeaders) -} - -object ServerResponse { - def notFound[B]: ServerResponse[B] = ServerResponse[B](StatusCode.NotFound, Nil, None, None) -} diff --git a/tapir/core/src/main/scala/sttp/tapir/server/model/ValuedEndpointOutput.scala b/tapir/core/src/main/scala/sttp/tapir/server/model/ValuedEndpointOutput.scala deleted file mode 100644 index eaeaa499..00000000 --- a/tapir/core/src/main/scala/sttp/tapir/server/model/ValuedEndpointOutput.scala +++ /dev/null @@ -1,11 +0,0 @@ -package sttp.tapir.server.model - -import sttp.tapir.EndpointOutput - -case class ValuedEndpointOutput[T](output: EndpointOutput[T], value: T) { - def prepend[U](otherOutput: EndpointOutput[U], otherValue: U): ValuedEndpointOutput[(U, T)] = - ValuedEndpointOutput(otherOutput.and(output), (otherValue, value)) - - def append[U](otherOutput: EndpointOutput[U], otherValue: U): ValuedEndpointOutput[(T, U)] = - ValuedEndpointOutput(output.and(otherOutput), (value, otherValue)) -} diff --git a/tapir/playjson/shared/src/main/scala/sttp/tapir/json/play/TapirJsonPlay.scala b/tapir/playjson/shared/src/main/scala/sttp/tapir/json/play/TapirJsonPlay.scala deleted file mode 100644 index fa4e2733..00000000 --- a/tapir/playjson/shared/src/main/scala/sttp/tapir/json/play/TapirJsonPlay.scala +++ /dev/null @@ -1,49 +0,0 @@ -package sttp.tapir.json.play - -import play.api.libs.json._ -import sttp.tapir._ -import sttp.tapir.SchemaType._ -import sttp.tapir.Codec.JsonCodec -import sttp.tapir.DecodeResult.Error.{JsonDecodeException, JsonError} -import sttp.tapir.DecodeResult.{Error, Value} -import sttp.tapir.Schema.SName - -import scala.util.{Failure, Success, Try} - -trait TapirJsonPlay { - def jsonBody[T: Reads: Writes: Schema]: EndpointIO.Body[String, T] = stringBodyUtf8AnyFormat(readsWritesCodec[T]) - - def jsonBodyWithRaw[T: Reads: Writes: Schema]: EndpointIO.Body[String, (String, T)] = stringBodyUtf8AnyFormat( - implicitly[JsonCodec[(String, T)]] - ) - - def jsonQuery[T: Reads: Writes: Schema](name: String): EndpointInput.Query[T] = - queryAnyFormat[T, CodecFormat.Json](name, Codec.jsonQuery(readsWritesCodec)) - - implicit def readsWritesCodec[T: Reads: Writes: Schema]: JsonCodec[T] = - Codec.json[T] { s => - Try(Json.parse(s)) match { - case Failure(exception) => - Error(s, JsonDecodeException(List.empty, exception)) - case Success(jsValue) => - implicitly[Reads[T]].reads(jsValue) match { - case JsError(errors) => - val jsonErrors = errors - .flatMap { case (path, validationErrors) => - val fields = path.toJsonString.split("\\.").toList.map(FieldName.apply) - validationErrors.map(error => fields -> error) - } - .map { case (fields, validationError) => - JsonError(validationError.message, fields) - } - .toList - Error(s, JsonDecodeException(jsonErrors, JsResultException(errors))) - case JsSuccess(value, _) => - Value(value) - } - } - } { t => Json.stringify(Json.toJson(t)) } - - implicit val schemaForPlayJsValue: Schema[JsValue] = Schema.any - implicit val schemaForPlayJsObject: Schema[JsObject] = Schema.anyObject[JsObject].name(SName("play.api.libs.json.JsObject")) -} diff --git a/tapir/playjson/shared/src/main/scala/sttp/tapir/json/play/package.scala b/tapir/playjson/shared/src/main/scala/sttp/tapir/json/play/package.scala deleted file mode 100644 index be8e9530..00000000 --- a/tapir/playjson/shared/src/main/scala/sttp/tapir/json/play/package.scala +++ /dev/null @@ -1,3 +0,0 @@ -package sttp.tapir.json - -package object play extends TapirJsonPlay diff --git a/tapir/tapir-play/src/main/scala/sttp/tapir/server/play/PlayBodyListener.scala b/tapir/tapir-play/src/main/scala/sttp/tapir/server/play/PlayBodyListener.scala deleted file mode 100644 index 699ecc8f..00000000 --- a/tapir/tapir-play/src/main/scala/sttp/tapir/server/play/PlayBodyListener.scala +++ /dev/null @@ -1,30 +0,0 @@ -package sttp.tapir.server.play - -import akka.Done -import play.api.http.HttpEntity -import sttp.tapir.server.interpreter.BodyListener - -import scala.concurrent.{ExecutionContext, Future} -import scala.util.{Failure, Success, Try} - -class PlayBodyListener(implicit ec: ExecutionContext) extends BodyListener[Future, PlayResponseBody] { - override def onComplete(body: PlayResponseBody)(cb: Try[Unit] => Future[Unit]): Future[PlayResponseBody] = { - - def onDone(f: Future[Done]): Unit = f.onComplete { - case Failure(ex) => cb(Failure(ex)) - case _ => cb(Success(())) - } - - body match { - case ws @ Left(_) => cb(Success(())).map(_ => ws) - case Right(r) => - (r match { - case e @ HttpEntity.Streamed(data, _, _) => - Future.successful(e.copy(data = data.watchTermination() { case (_, f) => onDone(f) })) - case e @ HttpEntity.Chunked(chunks, _) => - Future.successful(e.copy(chunks = chunks.watchTermination() { case (_, f) => onDone(f) })) - case e @ HttpEntity.Strict(_, _) => cb(Success(())).map(_ => e) - }).map(Right(_)) - } - } -} diff --git a/tapir/tapir-play/src/main/scala/sttp/tapir/server/play/PlayRequestBody.scala b/tapir/tapir-play/src/main/scala/sttp/tapir/server/play/PlayRequestBody.scala deleted file mode 100644 index ea294217..00000000 --- a/tapir/tapir-play/src/main/scala/sttp/tapir/server/play/PlayRequestBody.scala +++ /dev/null @@ -1,139 +0,0 @@ -package sttp.tapir.server.play - -import akka.stream.Materializer -import akka.stream.scaladsl.{FileIO, Source} -import akka.util.ByteString -import play.api.mvc.{Request, Result} -import play.core.parsers.Multipart -import sttp.capabilities.akka.AkkaStreams -import sttp.model.{Header, MediaType, Part} -import sttp.tapir.internal._ -import sttp.tapir.model.ServerRequest -import sttp.tapir.server.interpreter.{RawValue, RequestBody} -import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, RawPart} - -import java.io.{ByteArrayInputStream, File} -import java.nio.charset.Charset -import scala.concurrent.{ExecutionContext, Future} -import scala.collection.compat._ - -private[play] class PlayRequestBody(serverOptions: PlayServerOptions)(implicit - mat: Materializer -) extends RequestBody[Future, AkkaStreams] { - - override val streams: AkkaStreams = AkkaStreams - - override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): Future[RawValue[R]] = { - import mat.executionContext - val request = playRequest(serverRequest) - val charset = request.charset.map(Charset.forName) - toRaw(request, bodyType, charset, () => request.body, None) - } - - override def toStream(serverRequest: ServerRequest): streams.BinaryStream = playRequest(serverRequest).body - - private def toRaw[R]( - request: Request[AkkaStreams.BinaryStream], - bodyType: RawBodyType[R], - charset: Option[Charset], - body: () => Source[ByteString, Any], - bodyAsFile: Option[File] - )(implicit - mat: Materializer, - ec: ExecutionContext - ): Future[RawValue[R]] = { - // playBodyParsers is used, so that the maxLength limits from Play configuration are applied - def bodyAsByteString(): Future[ByteString] = { - serverOptions.playBodyParsers.byteString.apply(request).run(body()).flatMap { - case Left(result) => Future.failed(new PlayBodyParserException(result)) - case Right(value) => Future.successful(value) - } - } - bodyType match { - case RawBodyType.StringBody(defaultCharset) => - bodyAsByteString().map(b => RawValue(b.decodeString(charset.getOrElse(defaultCharset)))) - case RawBodyType.ByteArrayBody => bodyAsByteString().map(b => RawValue(b.toArray)) - case RawBodyType.ByteBufferBody => bodyAsByteString().map(b => RawValue(b.toByteBuffer)) - case RawBodyType.InputStreamBody => bodyAsByteString().map(b => RawValue(new ByteArrayInputStream(b.toArray))) - case RawBodyType.InputStreamRangeBody => - bodyAsByteString().map(b => RawValue(new InputStreamRange(() => new ByteArrayInputStream(b.toArray)))) - case RawBodyType.FileBody => - bodyAsFile match { - case Some(file) => - val tapirFile = FileRange(file) - Future.successful(RawValue(tapirFile, Seq(tapirFile))) - case None => - val file = FileRange(serverOptions.temporaryFileCreator.create().toFile) - serverOptions.playBodyParsers.file(file.file).apply(request).run(body()).flatMap { - case Left(result) => Future.failed(new PlayBodyParserException(result)) - case Right(_) => Future.successful(RawValue(file, Seq(file))) - } - } - case m: RawBodyType.MultipartBody => multiPartRequestToRawBody(request, m, body) - } - } - - private def multiPartRequestToRawBody( - request: Request[AkkaStreams.BinaryStream], - m: RawBodyType.MultipartBody, - body: () => Source[ByteString, Any] - )(implicit - mat: Materializer, - ec: ExecutionContext - ): Future[RawValue[Seq[RawPart]]] = { - val bodyParser = serverOptions.playBodyParsers.multipartFormData( - Multipart.handleFilePartAsTemporaryFile(serverOptions.temporaryFileCreator) - ) - bodyParser.apply(request).run(body()).flatMap { - case Left(r) => - Future.failed(new PlayBodyParserException(r)) - case Right(value) => - val dataParts: Seq[Future[Option[Part[Any]]]] = - value.dataParts.flatMap { case (key, value: scala.collection.Seq[String]) => - m.partType(key).map { partType => - val data = value.map(ByteString.apply).to(scala.collection.immutable.Seq) - val contentLength = Header.contentLength(data.map(_.length.toLong).sum) - toRaw( - request.withHeaders(request.headers.replace(contentLength.name -> contentLength.value)), - partType, - charset(partType), - () => Source(data), - None - ).map(body => Some(Part(key, body.value))) - } - }.toSeq - - val fileParts: Seq[Future[Option[Part[Any]]]] = value.files.map { f => - m.partType(f.key) - .map { partType => - toRaw( - request, - partType, - charset(partType), - () => FileIO.fromPath(f.ref.path), - Some(f.ref.toFile) - ).map(body => - Some( - Part( - f.key, - body.value, - Map(f.key -> f.dispositionType, Part.FileNameDispositionParam -> f.filename), - f.contentType.flatMap(MediaType.parse(_).toOption).map(Header.contentType).toList - ) - .asInstanceOf[RawPart] - ) - ) - } - .getOrElse { - serverOptions.deleteFile(f.ref.toFile).map(_ => Option.empty) - } - } - Future.sequence(dataParts ++ fileParts).map(ps => ps.collect { case Some(p) => p }).map(RawValue.fromParts) - } - } - - private def playRequest(serverRequest: ServerRequest) = - serverRequest.underlying.asInstanceOf[Request[Source[ByteString, Any]]] -} - -class PlayBodyParserException(val result: Result) extends Exception diff --git a/tapir/tapir-play/src/main/scala/sttp/tapir/server/play/PlayServerInterpreter.scala b/tapir/tapir-play/src/main/scala/sttp/tapir/server/play/PlayServerInterpreter.scala deleted file mode 100644 index e53fffbf..00000000 --- a/tapir/tapir-play/src/main/scala/sttp/tapir/server/play/PlayServerInterpreter.scala +++ /dev/null @@ -1,157 +0,0 @@ -package sttp.tapir.server.play - -import akka.stream.Materializer -import akka.stream.scaladsl.{Flow, Source} -import akka.util.ByteString -import play.api.http.websocket.Message -import play.api.http.{HeaderNames, HttpEntity} -import play.api.libs.streams.Accumulator -import play.api.mvc._ -import play.api.routing.Router.Routes -import sttp.capabilities.WebSockets -import sttp.capabilities.akka.AkkaStreams -import sttp.model.Method -import sttp.monad.FutureMonad -import sttp.tapir.server.ServerEndpoint -import sttp.tapir.server.interceptor.RequestResult -import sttp.tapir.server.interpreter.{BodyListener, FilterServerEndpoints, ServerInterpreter} -import sttp.tapir.server.model.ServerResponse - -import scala.concurrent.{ExecutionContext, Future} - -trait PlayServerInterpreter { - - implicit def mat: Materializer - - implicit def executionContext: ExecutionContext = mat.executionContext - - def playServerOptions: PlayServerOptions = PlayServerOptions.default - - private val streamParser: BodyParser[AkkaStreams.BinaryStream] = BodyParser { _ => - Accumulator.source[ByteString].map(Right.apply) - } - - def toRoutes(e: ServerEndpoint[AkkaStreams with WebSockets, Future]): Routes = { - toRoutes(List(e)) - } - - def toRoutes( - serverEndpoints: List[ServerEndpoint[AkkaStreams with WebSockets, Future]] - ): Routes = { - implicit val monad: FutureMonad = new FutureMonad() - - val filterServerEndpoints = FilterServerEndpoints(serverEndpoints) - val singleEndpoint = serverEndpoints.size == 1 - - new PartialFunction[RequestHeader, Handler] { - override def isDefinedAt(request: RequestHeader): Boolean = { - val filtered = filterServerEndpoints(PlayServerRequest(request, request)) - if (singleEndpoint) { - // If we are interpreting a single endpoint, we verify that the method matches as well; in case it doesn't, - // we refuse to handle the request, allowing other Play routes to handle it. Otherwise even if the method - // doesn't match, this will be handled by the RejectInterceptor - filtered.exists { e => - val m = e.endpoint.method - m.isEmpty || m.contains(Method(request.method)) - } - } else { - filtered.nonEmpty - } - } - - override def apply(header: RequestHeader): Handler = - if (isWebSocket(header)) - WebSocket.acceptOrResult { header => - getResponse(header, header.withBody(Source.empty)) - } - else - playServerOptions.defaultActionBuilder.async(streamParser) { request => - getResponse(header, request).flatMap { - case Left(result) => Future.successful(result) - case Right(_) => Future.failed(new Exception("Only WebSocket requests accept flows.")) - } - } - - private def getResponse( - header: RequestHeader, - request: Request[AkkaStreams.BinaryStream] - ): Future[Either[Result, Flow[Message, Message, Any]]] = { - implicit val bodyListener: BodyListener[Future, PlayResponseBody] = new PlayBodyListener - val serverRequest = PlayServerRequest(header, request) - val interpreter = new ServerInterpreter( - filterServerEndpoints, - new PlayRequestBody(playServerOptions), - new PlayToResponseBody, - playServerOptions.interceptors, - playServerOptions.deleteFile - ) - - interpreter(serverRequest) - .map { - case RequestResult.Failure(_) => - throw new RuntimeException( - s"The path: ${request.path} matches the shape of some endpoint, but none of the " + - s"endpoints decoded the request successfully, and the decode failure handler didn't provide a " + - s"response. Play requires that if the path shape matches some endpoints, the request " + - s"should be handled by tapir." - ) - case RequestResult.Response(response: ServerResponse[PlayResponseBody]) => - val headers: Map[String, String] = response.headers - .foldLeft(Map.empty[String, List[String]]) { (a, b) => - if (a.contains(b.name)) a + (b.name -> (a(b.name) :+ b.value)) else a + (b.name -> List(b.value)) - } - .map { - // See comment in play.api.mvc.CookieHeaderEncoding - case (key, value) if key == HeaderNames.SET_COOKIE => (key, value.mkString(";;")) - case (key, value) => (key, value.mkString(", ")) - } - .filterNot(allowToSetExplicitly) - - val status = response.code.code - response.body match { - case Some(Left(flow)) => Right(flow) - case Some(Right(entity)) => Left(Result(ResponseHeader(status, headers), entity)) - case None => - if (serverRequest.method.is(Method.HEAD) && response.contentLength.isDefined) - Left( - Result( - ResponseHeader(status, headers), - HttpEntity.Streamed(Source.empty, response.contentLength, response.contentType) - ) - ) - else Left(Result(ResponseHeader(status, headers), HttpEntity.Strict(ByteString.empty, response.contentType))) - } - } - .recover { case e: PlayBodyParserException => - Left(e.result) - } - } - } - } - - private def isWebSocket(header: RequestHeader): Boolean = - (for { - connection <- header.headers.get(sttp.model.HeaderNames.Connection) - upgrade <- header.headers.get(sttp.model.HeaderNames.Upgrade) - } yield connection.equalsIgnoreCase("Upgrade") && upgrade.equalsIgnoreCase("websocket")).getOrElse(false) - - private def allowToSetExplicitly(header: (String, String)): Boolean = - List(HeaderNames.CONTENT_TYPE, HeaderNames.CONTENT_LENGTH, HeaderNames.TRANSFER_ENCODING).contains(header._1) - -} - -object PlayServerInterpreter { - def apply()(implicit _mat: Materializer): PlayServerInterpreter = { - new PlayServerInterpreter { - override implicit def mat: Materializer = _mat - } - } - - def apply(serverOptions: PlayServerOptions)(implicit _mat: Materializer): PlayServerInterpreter = { - new PlayServerInterpreter { - override implicit def mat: Materializer = _mat - - override def playServerOptions: PlayServerOptions = serverOptions - } - } -} diff --git a/tapir/tapir-play/src/main/scala/sttp/tapir/server/play/PlayServerOptions.scala b/tapir/tapir-play/src/main/scala/sttp/tapir/server/play/PlayServerOptions.scala deleted file mode 100644 index 362e5ba2..00000000 --- a/tapir/tapir-play/src/main/scala/sttp/tapir/server/play/PlayServerOptions.scala +++ /dev/null @@ -1,80 +0,0 @@ -package sttp.tapir.server.play - -import akka.stream.Materializer -import com.typesafe.config.ConfigFactory -import play.api.http.ParserConfiguration -import play.api.Logger -import play.api.libs.Files.{SingletonTemporaryFileCreator, TemporaryFileCreator} -import play.api.mvc._ -import sttp.tapir.{Defaults, TapirFile} -import sttp.tapir.server.interceptor.decodefailure.DecodeFailureHandler -import sttp.tapir.server.interceptor.log.DefaultServerLog -import sttp.tapir.server.interceptor.{CustomiseInterceptors, Interceptor} - -import scala.concurrent.{ExecutionContext, Future, blocking} - -case class PlayServerOptions( - temporaryFileCreator: TemporaryFileCreator, - deleteFile: TapirFile => Future[Unit], - defaultActionBuilder: ActionBuilder[Request, AnyContent], - playBodyParsers: PlayBodyParsers, - decodeFailureHandler: DecodeFailureHandler, - interceptors: List[Interceptor[Future]] -) { - def prependInterceptor(i: Interceptor[Future]): PlayServerOptions = copy(interceptors = i :: interceptors) - def appendInterceptor(i: Interceptor[Future]): PlayServerOptions = copy(interceptors = interceptors :+ i) -} - -object PlayServerOptions { - - /** Allows customising the interceptors used by the server interpreter. */ - def customiseInterceptors(conf: ParserConfiguration = defaultParserConfiguration)(implicit - mat: Materializer, - ec: ExecutionContext - ): CustomiseInterceptors[Future, PlayServerOptions] = - CustomiseInterceptors( - createOptions = (ci: CustomiseInterceptors[Future, PlayServerOptions]) => - PlayServerOptions( - SingletonTemporaryFileCreator, - defaultDeleteFile(_), - DefaultActionBuilder.apply(PlayBodyParsers.apply(conf = conf).anyContent), - PlayBodyParsers.apply(conf = conf), - ci.decodeFailureHandler, - ci.interceptors - ) - ).serverLog(defaultServerLog) - - def defaultDeleteFile(file: TapirFile)(implicit ec: ExecutionContext): Future[Unit] = { - Future(blocking(Defaults.deleteFile()(file))) - } - - lazy val defaultServerLog: DefaultServerLog[Future] = { - DefaultServerLog( - doLogWhenReceived = debugLog(_, None), - doLogWhenHandled = debugLog, - doLogAllDecodeFailures = debugLog, - doLogExceptions = (msg: String, ex: Throwable) => Future.successful { logger.error(msg, ex) }, - noLog = Future.successful(()) - ) - } - - private def debugLog(msg: String, exOpt: Option[Throwable]): Future[Unit] = Future.successful { - exOpt match { - case None => logger.debug(msg) - case Some(ex) => logger.debug(s"$msg; exception: {}", ex) - } - } - - def default(implicit mat: Materializer, ec: ExecutionContext): PlayServerOptions = customiseInterceptors().options - - private lazy val conf = ConfigFactory.load - - lazy val defaultParserConfiguration = { - ParserConfiguration( - maxMemoryBuffer = conf.getMemorySize("play.http.parser.maxMemoryBuffer").toBytes, - maxDiskBuffer = conf.getMemorySize("play.http.parser.maxDiskBuffer").toBytes - ) - } - - lazy val logger: Logger = Logger(this.getClass.getPackage.getName) -} diff --git a/tapir/tapir-play/src/main/scala/sttp/tapir/server/play/PlayServerRequest.scala b/tapir/tapir-play/src/main/scala/sttp/tapir/server/play/PlayServerRequest.scala deleted file mode 100644 index 92aa085c..00000000 --- a/tapir/tapir-play/src/main/scala/sttp/tapir/server/play/PlayServerRequest.scala +++ /dev/null @@ -1,34 +0,0 @@ -package sttp.tapir.server.play - -import play.api.mvc.RequestHeader -import play.utils.UriEncoding -import sttp.model.{Header, Method, QueryParams, Uri} -import sttp.tapir.{AttributeKey, AttributeMap} -import sttp.tapir.model.{ConnectionInfo, ServerRequest} - -import java.nio.charset.StandardCharsets -import scala.collection.immutable._ - -private[play] case class PlayServerRequest( - requestHeader: RequestHeader, - requestWithContext: RequestHeader, - attributes: AttributeMap = AttributeMap.Empty -) extends ServerRequest { - override lazy val method: Method = Method(requestHeader.method.toUpperCase) - override def protocol: String = requestHeader.version - override lazy val uri: Uri = Uri.unsafeParse(requestHeader.uri) - override lazy val connectionInfo: ConnectionInfo = ConnectionInfo(None, None, Some(requestHeader.secure)) - override lazy val headers: Seq[Header] = requestHeader.headers.headers.map { case (k, v) => Header(k, v) }.toList - override lazy val queryParameters: QueryParams = QueryParams.fromMultiMap(requestHeader.queryString) - override lazy val pathSegments: List[String] = { - val segments = requestHeader.path.dropWhile(_ == '/').split("/").toList.map(UriEncoding.decodePathSegment(_, StandardCharsets.UTF_8)) - if (segments == List("")) Nil else segments // representing the root path as an empty list - } - - override def attribute[T](k: AttributeKey[T]): Option[T] = attributes.get(k) - override def attribute[T](k: AttributeKey[T], v: T): PlayServerRequest = copy(attributes = attributes.put(k, v)) - - override def underlying: Any = requestWithContext - override def withUnderlying(underlying: Any): ServerRequest = - new PlayServerRequest(requestHeader, requestWithContext = underlying.asInstanceOf[RequestHeader], attributes) -} diff --git a/tapir/tapir-play/src/main/scala/sttp/tapir/server/play/PlayToResponseBody.scala b/tapir/tapir-play/src/main/scala/sttp/tapir/server/play/PlayToResponseBody.scala deleted file mode 100644 index dbcb5abb..00000000 --- a/tapir/tapir-play/src/main/scala/sttp/tapir/server/play/PlayToResponseBody.scala +++ /dev/null @@ -1,202 +0,0 @@ -package sttp.tapir.server.play - -import akka.NotUsed -import akka.stream.scaladsl.{FileIO, Source, StreamConverters} -import akka.util.ByteString -import play.api.http.{HeaderNames, HttpChunk, HttpEntity} -import play.api.mvc.MultipartFormData.{DataPart, FilePart} -import play.api.mvc.{Codec, MultipartFormData} -import sttp.capabilities.akka.AkkaStreams -import sttp.model.{HasHeaders, Part} -import sttp.tapir.server.interpreter.ToResponseBody -import sttp.tapir.{CodecFormat, FileRange, RawBodyType, RawPart, WebSocketBodyOutput} - -import java.nio.ByteBuffer -import java.nio.charset.Charset - -class PlayToResponseBody extends ToResponseBody[PlayResponseBody, AkkaStreams] { - - override val streams: AkkaStreams = AkkaStreams - - override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): PlayResponseBody = { - Right(fromRawValue(v, headers, bodyType)) - } - - private val ChunkSize = 8192 - - private def fromRawValue[R](v: R, headers: HasHeaders, bodyType: RawBodyType[R]): HttpEntity = { - val contentType = headers.contentType - bodyType match { - case RawBodyType.StringBody(charset) => - val str = v.asInstanceOf[String] - HttpEntity.Strict(ByteString(str, charset), contentType) - - case RawBodyType.ByteArrayBody => - val bytes = v.asInstanceOf[Array[Byte]] - HttpEntity.Strict(ByteString(bytes), contentType) - - case RawBodyType.ByteBufferBody => - val byteBuffer = v.asInstanceOf[ByteBuffer] - HttpEntity.Strict(ByteString(byteBuffer), contentType) - - case RawBodyType.InputStreamBody => - streamOrChunk(StreamConverters.fromInputStream(() => v), headers.contentLength, contentType) - - case RawBodyType.InputStreamRangeBody => - val initialStream = StreamConverters.fromInputStream(v.inputStreamFromRangeStart, ChunkSize) - v.range - .map(r => streamOrChunk(toRangedStream(initialStream, bytesTotal = r.contentLength), Some(r.contentLength), contentType)) - .getOrElse(streamOrChunk(initialStream, headers.contentLength, contentType)) - - case RawBodyType.FileBody => - v.range - .flatMap(r => - r.startAndEnd - .map(s => streamOrChunk(createFileSource(v, s._1, r.contentLength), Some(r.contentLength), contentType)) - ) - .getOrElse(streamOrChunk(FileIO.fromPath(v.file.toPath), Some(v.file.length()), contentType)) - - case m: RawBodyType.MultipartBody => - val rawParts = v.asInstanceOf[Seq[RawPart]] - - val dataParts = rawParts - .filter { part => - m.partType(part.name).exists { - case RawBodyType.StringBody(_) => true - case RawBodyType.ByteArrayBody => true - case RawBodyType.ByteBufferBody => true - case _ => false - } - } - .flatMap(rawPartsToDataPart(m, _)) - - val fileParts = rawParts - .filter { part => - m.partType(part.name).exists { - case RawBodyType.InputStreamBody => true - case RawBodyType.FileBody => true - case _ => false - } - } - .flatMap(rawPartsToFilePart(m, _)) - - HttpEntity.Streamed(multipartFormToStream(dataParts, fileParts), None, contentType) - } - } - - private def createFileSource( - tapirFile: FileRange, - start: Long, - bytesTotal: Long - ): AkkaStreams.BinaryStream = - toRangedStream(FileIO.fromPath(tapirFile.file.toPath, ChunkSize, startPosition = start), bytesTotal) - - private def toRangedStream(initialStream: AkkaStreams.BinaryStream, bytesTotal: Long): AkkaStreams.BinaryStream = - initialStream - .scan((0L, ByteString.empty)) { case ((bytesConsumed, _), next) => - val bytesInNext = next.length - val bytesFromNext = Math.max(0, Math.min(bytesTotal - bytesConsumed, bytesInNext.toLong)) - (bytesConsumed + bytesInNext, next.take(bytesFromNext.toInt)) - } - .takeWhile(_._1 < bytesTotal, inclusive = true) - .map(_._2) - - override def fromStreamValue( - v: streams.BinaryStream, - headers: HasHeaders, - format: CodecFormat, - charset: Option[Charset] - ): PlayResponseBody = { - Right(streamOrChunk(v, headers.contentLength, Option(headers.contentType.getOrElse(formatToContentType(format, charset))))) - } - - private def streamOrChunk(stream: streams.BinaryStream, contentLength: Option[Long], contentType: Option[String]): HttpEntity = { - contentLength match { - case Some(length) => - HttpEntity.Streamed(stream, Some(length), contentType) - case None => - val chunkStream = stream.map(HttpChunk.Chunk.apply) - HttpEntity.Chunked(chunkStream, contentType) - } - } - - override def fromWebSocketPipe[REQ, RESP]( - pipe: streams.Pipe[REQ, RESP], - o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, AkkaStreams] - ): PlayResponseBody = Left(PlayWebSockets.pipeToBody(pipe, o)) - - private def rawPartsToFilePart[T]( - m: RawBodyType.MultipartBody, - part: Part[T] - ): Option[MultipartFormData.FilePart[Source[ByteString, _]]] = { - m.partType(part.name).flatMap { partType => - val entity: HttpEntity = fromRawValue(part.body, part, partType.asInstanceOf[RawBodyType[Any]]) - - for { - fileName <- part.fileName - contentLength <- entity.contentLength - dispositionType <- part.otherDispositionParams.get(part.name) - } yield MultipartFormData.FilePart(part.name, fileName, entity.contentType, entity.dataStream, contentLength, dispositionType) - } - } - - private def rawPartsToDataPart[T](m: RawBodyType.MultipartBody, part: Part[T]): Option[MultipartFormData.DataPart] = { - m.partType(part.name).flatMap { partType => - val charset = partType match { - case valueType: RawBodyType.StringBody => valueType.charset - case _ => Charset.defaultCharset() - } - - val maybeData: Option[String] = - fromRawValue(part.body, part, partType.asInstanceOf[RawBodyType[Any]]) match { - case HttpEntity.Strict(data, _) => Some(data.decodeString(charset)) - case HttpEntity.Streamed(_, _, _) => None - case HttpEntity.Chunked(_, _) => None - } - - maybeData.map(MultipartFormData.DataPart(part.name, _)) - } - } - - private def formatToContentType(format: CodecFormat, charset: Option[Charset]): String = - charset.fold(format.mediaType)(format.mediaType.charset(_)).toString() - - private def multipartFormToStream[A]( - dataParts: Seq[DataPart], - fileParts: Seq[FilePart[Source[ByteString, _]]] - ): Source[ByteString, NotUsed] = { - val boundary: String = "--------" + scala.util.Random.alphanumeric.take(20).mkString("") - - def formatDataParts(dataParts: Seq[DataPart]) = { - val result = dataParts - .flatMap { case DataPart(name, value) => - s""" - --$boundary\r\n${HeaderNames.CONTENT_DISPOSITION}: form-data; name="$name"\r\n\r\n$value\r\n - """.stripMargin - } - .mkString("") - Codec.utf_8.encode(result) - } - - def filePartHeader(file: FilePart[_]) = { - val name = s""""${file.key}"""" - val filename = s""""${file.filename}"""" - val contentType = file.contentType - .map { ct => s"${HeaderNames.CONTENT_TYPE}: $ct\r\n" } - .getOrElse("") - Codec.utf_8.encode( - s"--$boundary\r\n${HeaderNames.CONTENT_DISPOSITION}: form-data; name=$name; filename=$filename\r\n$contentType\r\n" - ) - } - - Source - .single(formatDataParts(dataParts)) - .concat(Source(fileParts.toList).flatMapConcat { file => - Source - .single(filePartHeader(file)) - .concat(file.ref) - .concat(Source.single(ByteString("\r\n", Charset.forName("UTF-8")))) - .concat(Source.single(ByteString(s"--$boundary--", "UTF-8"))) - }) - } -} diff --git a/tapir/tapir-play/src/main/scala/sttp/tapir/server/play/PlayWebSockets.scala b/tapir/tapir-play/src/main/scala/sttp/tapir/server/play/PlayWebSockets.scala deleted file mode 100644 index dbf5d618..00000000 --- a/tapir/tapir-play/src/main/scala/sttp/tapir/server/play/PlayWebSockets.scala +++ /dev/null @@ -1,54 +0,0 @@ -package sttp.tapir.server.play - -import akka.stream.scaladsl.Flow -import akka.util.ByteString -import play.api.http.websocket._ -import sttp.capabilities.akka.AkkaStreams -import sttp.tapir.model.WebSocketFrameDecodeFailure -import sttp.tapir.{DecodeResult, WebSocketBodyOutput} -import sttp.ws.WebSocketFrame - -private[play] object PlayWebSockets { - def pipeToBody[REQ, RESP]( - pipe: Flow[REQ, RESP, Any], - o: WebSocketBodyOutput[Flow[REQ, RESP, Any], REQ, RESP, _, AkkaStreams] - ): Flow[Message, Message, Any] = { - Flow[Message] - .map(messageToFrame) - .collect { case data: WebSocketFrame.Data[_] => - data - } - .map(f => - o.requests.decode(f) match { - case failure: DecodeResult.Failure => throw new WebSocketFrameDecodeFailure(f, failure) - case DecodeResult.Value(v) => v - } - ) - .via(pipe) - .map(o.responses.encode) - .takeWhile { - case WebSocketFrame.Close(_, _) => false - case _ => true - } - .mapConcat(frameToMessage(_).toList) - } - - private def messageToFrame(m: Message): WebSocketFrame = - m match { - case msg: TextMessage => WebSocketFrame.text(msg.data) - case msg: BinaryMessage => WebSocketFrame.binary(msg.data.toArray) - case msg: PingMessage => WebSocketFrame.Ping(msg.data.toArray) - case msg: PongMessage => WebSocketFrame.Pong(msg.data.toArray) - case msg: CloseMessage => WebSocketFrame.Close(msg.statusCode.getOrElse(WebSocketFrame.close.statusCode), msg.reason) - } - - private def frameToMessage(w: WebSocketFrame): Option[Message] = { - w match { - case WebSocketFrame.Text(p, _, _) => Some(TextMessage(p)) - case WebSocketFrame.Binary(p, _, _) => Some(BinaryMessage(ByteString(p))) - case WebSocketFrame.Ping(p) => Some(PingMessage(ByteString(p))) - case WebSocketFrame.Pong(p) => Some(PongMessage(ByteString(p))) - case WebSocketFrame.Close(code, text) => Some(CloseMessage(code, text)) - } - } -} diff --git a/tapir/tapir-play/src/main/scala/sttp/tapir/server/play/package.scala b/tapir/tapir-play/src/main/scala/sttp/tapir/server/play/package.scala deleted file mode 100644 index 3617b0ff..00000000 --- a/tapir/tapir-play/src/main/scala/sttp/tapir/server/play/package.scala +++ /dev/null @@ -1,9 +0,0 @@ -package sttp.tapir.server - -import _root_.play.api.http.HttpEntity -import _root_.play.api.http.websocket.Message -import akka.stream.scaladsl.Flow - -package object play { - type PlayResponseBody = Either[Flow[Message, Message, Any], HttpEntity] -}