diff --git a/README.md b/README.md index ba450e7c..4cf8382f 100644 --- a/README.md +++ b/README.md @@ -40,12 +40,12 @@ import com.avast.grpc.jsonbridge.ReflectionGrpcJsonBridge // for whole server val grpcServer: io.grpc.Server = ??? -val bridge = new ReflectionGrpcJsonBridge[Task](grpcServer) +val bridge = ReflectionGrpcJsonBridge.createFromServer[Task](grpcServer) // or for selected services val s1: ServerServiceDefinition = ??? val s2: ServerServiceDefinition = ??? -val anotherBridge = new ReflectionGrpcJsonBridge[Task](s1, s2) +val anotherBridge = ReflectionGrpcJsonBridge.createFromServices[Task](s1, s2) // call a method manually, with a header specified val jsonResponse = bridge.invoke("com.avast.grpc.jsonbridge.test.TestService/Add", """ { "a": 1, "b": 2} """, Map("My-Header" -> "value")) diff --git a/akka-http/src/test/scala/com/avast/grpc/jsonbridge/akkahttp/AkkaHttpTest.scala b/akka-http/src/test/scala/com/avast/grpc/jsonbridge/akkahttp/AkkaHttpTest.scala index b6d70e6d..1024dd2f 100644 --- a/akka-http/src/test/scala/com/avast/grpc/jsonbridge/akkahttp/AkkaHttpTest.scala +++ b/akka-http/src/test/scala/com/avast/grpc/jsonbridge/akkahttp/AkkaHttpTest.scala @@ -7,14 +7,24 @@ import akka.http.scaladsl.testkit.ScalatestRouteTest import cats.data.NonEmptyList import cats.effect.IO import com.avast.grpc.jsonbridge._ +import io.grpc.ServerServiceDefinition import org.scalatest.FunSuite +import scala.concurrent.ExecutionContext import scala.util.Random class AkkaHttpTest extends FunSuite with ScalatestRouteTest { + val ec: ExecutionContext = implicitly[ExecutionContext] + def bridge(ssd: ServerServiceDefinition): GrpcJsonBridge[IO] = + ReflectionGrpcJsonBridge + .createFromServices[IO](ec)(ssd) + .allocated + .unsafeRunSync() + ._1 + test("basic") { - val route = AkkaHttp[IO](Configuration.Default)(new ReflectionGrpcJsonBridge[IO](TestServiceImpl.bindService())) + val route = AkkaHttp[IO](Configuration.Default)(bridge(TestServiceImpl.bindService())) Post("/com.avast.grpc.jsonbridge.test.TestService/Add", """ { "a": 1, "b": 2} """) .withHeaders(AkkaHttp.JsonContentType) ~> route ~> check { assertResult(StatusCodes.OK)(status) @@ -25,7 +35,7 @@ class AkkaHttpTest extends FunSuite with ScalatestRouteTest { test("with path prefix") { val configuration = Configuration.Default.copy(pathPrefix = Some(NonEmptyList.of("abc", "def"))) - val route = AkkaHttp[IO](configuration)(new ReflectionGrpcJsonBridge[IO](TestServiceImpl.bindService())) + val route = AkkaHttp[IO](configuration)(bridge(TestServiceImpl.bindService())) Post("/abc/def/com.avast.grpc.jsonbridge.test.TestService/Add", """ { "a": 1, "b": 2} """) .withHeaders(AkkaHttp.JsonContentType) ~> route ~> check { assertResult(StatusCodes.OK)(status) @@ -34,7 +44,7 @@ class AkkaHttpTest extends FunSuite with ScalatestRouteTest { } test("bad request after wrong request") { - val route = AkkaHttp[IO](Configuration.Default)(new ReflectionGrpcJsonBridge[IO](TestServiceImpl.bindService())) + val route = AkkaHttp[IO](Configuration.Default)(bridge(TestServiceImpl.bindService())) // empty body Post("/com.avast.grpc.jsonbridge.test.TestService/Add", "") .withHeaders(AkkaHttp.JsonContentType) ~> route ~> check { @@ -47,7 +57,7 @@ class AkkaHttpTest extends FunSuite with ScalatestRouteTest { } test("propagates user-specified status") { - val route = AkkaHttp(Configuration.Default)(new ReflectionGrpcJsonBridge[IO](PermissionDeniedTestServiceImpl.bindService())) + val route = AkkaHttp(Configuration.Default)(bridge(PermissionDeniedTestServiceImpl.bindService())) Post(s"/com.avast.grpc.jsonbridge.test.TestService/Add", """ { "a": 1, "b": 2} """) .withHeaders(AkkaHttp.JsonContentType) ~> route ~> check { assertResult(status)(StatusCodes.Forbidden) @@ -55,7 +65,7 @@ class AkkaHttpTest extends FunSuite with ScalatestRouteTest { } test("provides service description") { - val route = AkkaHttp[IO](Configuration.Default)(new ReflectionGrpcJsonBridge[IO](TestServiceImpl.bindService())) + val route = AkkaHttp[IO](Configuration.Default)(bridge(TestServiceImpl.bindService())) Get("/com.avast.grpc.jsonbridge.test.TestService") ~> route ~> check { assertResult(StatusCodes.OK)(status) assertResult("com.avast.grpc.jsonbridge.test.TestService/Add")(responseAs[String]) @@ -63,7 +73,7 @@ class AkkaHttpTest extends FunSuite with ScalatestRouteTest { } test("provides services description") { - val route = AkkaHttp[IO](Configuration.Default)(new ReflectionGrpcJsonBridge[IO](TestServiceImpl.bindService())) + val route = AkkaHttp[IO](Configuration.Default)(bridge(TestServiceImpl.bindService())) Get("/") ~> route ~> check { assertResult(StatusCodes.OK)(status) assertResult("com.avast.grpc.jsonbridge.test.TestService/Add")(responseAs[String]) @@ -72,7 +82,7 @@ class AkkaHttpTest extends FunSuite with ScalatestRouteTest { test("passes headers") { val headerValue = Random.alphanumeric.take(10).mkString("") - val route = AkkaHttp[IO](Configuration.Default)(new ReflectionGrpcJsonBridge[IO](TestServiceImpl.withInterceptor)) + val route = AkkaHttp[IO](Configuration.Default)(bridge(TestServiceImpl.withInterceptor)) val Ok(customHeaderToBeSent, _) = HttpHeader.parse(TestServiceImpl.HeaderName, headerValue) Post("/com.avast.grpc.jsonbridge.test.TestService/Add", """ { "a": 1, "b": 2} """) .withHeaders(AkkaHttp.JsonContentType, customHeaderToBeSent) ~> route ~> check { diff --git a/core/src/main/scala/com/avast/grpc/jsonbridge/ReflectionGrpcJsonBridge.scala b/core/src/main/scala/com/avast/grpc/jsonbridge/ReflectionGrpcJsonBridge.scala index 730547c0..ee215b2d 100644 --- a/core/src/main/scala/com/avast/grpc/jsonbridge/ReflectionGrpcJsonBridge.scala +++ b/core/src/main/scala/com/avast/grpc/jsonbridge/ReflectionGrpcJsonBridge.scala @@ -2,93 +2,105 @@ package com.avast.grpc.jsonbridge import java.lang.reflect.Method -import cats.effect.Async -import cats.syntax.all._ +import cats.effect._ +import cats.implicits._ import com.avast.grpc.jsonbridge.GrpcJsonBridge.GrpcMethodName -import com.google.common.util.concurrent._ +import com.google.common.util.concurrent.{FutureCallback, Futures, ListenableFuture} import com.google.protobuf.util.JsonFormat import com.google.protobuf.{Message, MessageOrBuilder} import com.typesafe.scalalogging.StrictLogging import io.grpc.MethodDescriptor.{MethodType, PrototypeMarshaller} -import io.grpc._ import io.grpc.inprocess.{InProcessChannelBuilder, InProcessServerBuilder} +import io.grpc._ import io.grpc.stub.AbstractStub import scala.collection.JavaConverters._ import scala.concurrent.ExecutionContext -import scala.language.{existentials, higherKinds} +import scala.language.existentials +import scala.language.higherKinds import scala.util.control.NonFatal -class ReflectionGrpcJsonBridge[F[_]](services: ServerServiceDefinition*)(implicit ec: ExecutionContext, F: Async[F]) - extends GrpcJsonBridge[F] - with AutoCloseable - with StrictLogging { - - import com.avast.grpc.jsonbridge.ReflectionGrpcJsonBridge._ +object ReflectionGrpcJsonBridge extends StrictLogging { - def this(grpcServer: io.grpc.Server)(implicit ec: ExecutionContext, F: Async[F]) = this(grpcServer.getImmutableServices.asScala: _*) + // JSON body and headers to a response (fail status or JSON response) + type HandlerFunc[F[_]] = (String, Map[String, String]) => F[Either[Status, String]] - private val inProcessServiceName = s"HttpWrapper-${System.nanoTime()}" + private val parser: JsonFormat.Parser = JsonFormat.parser() - private val inProcessServer = { - val b = InProcessServerBuilder.forName(inProcessServiceName).executor(ec.execute(_)) - services.foreach(b.addService) - b.build().start() + private val printer: JsonFormat.Printer = { + JsonFormat.printer().includingDefaultValueFields().omittingInsignificantWhitespace() } - private val inProcessChannel = InProcessChannelBuilder.forName(inProcessServiceName).executor(ec.execute(_)).build() + def createFromServer[F[_]](ec: ExecutionContext)(grpcServer: io.grpc.Server)(implicit F: Async[F]): Resource[F, GrpcJsonBridge[F]] = { + createFromServices(ec)(grpcServer.getImmutableServices.asScala: _*) + } - override def close(): Unit = { - inProcessChannel.shutdownNow() - inProcessServer.shutdownNow() - () + def createFromServices[F[_]](ec: ExecutionContext)(services: ServerServiceDefinition*)( + implicit F: Async[F]): Resource[F, GrpcJsonBridge[F]] = { + for { + inProcessServiceName <- Resource.liftF(F.delay { s"ReflectionGrpcJsonBridge-${System.nanoTime()}" }) + inProcessServer <- createInProcessServer(ec)(inProcessServiceName, services) + inProcessChannel <- createInProcessChannel(ec)(inProcessServiceName) + handlersPerMethod = inProcessServer.getImmutableServices.asScala + .flatMap(createServiceHandlers(ec)(inProcessChannel)(_)) + .toMap + bridge = createFromHandlers(handlersPerMethod) + } yield bridge } - // map from full method name to a function that invokes that method - protected val handlersPerMethod: Map[String, HandlerFunc[F]] = - inProcessServer.getImmutableServices.asScala - .flatMap(createServiceHandlers(inProcessChannel)(_)) - .toMap + def createFromHandlers[F[_]](handlersPerMethod: Map[String, HandlerFunc[F]])(implicit F: Async[F]): GrpcJsonBridge[F] = { + new GrpcJsonBridge[F] { + override def invoke(methodName: GrpcJsonBridge.GrpcMethodName, + body: String, + headers: Map[String, String]): F[Either[Status, String]] = handlersPerMethod.get(methodName.fullName) match { + case None => F.pure(Left(Status.NOT_FOUND.withDescription(s"Method '$methodName' not found"))) + case Some(handler) => + handler(body, headers) + .recover { + case NonFatal(ex) => + val message = "Error while executing the request" + logger.info(message, ex) + ex match { + case e: StatusException if e.getStatus.getCode == Status.Code.UNKNOWN => + Left(richStatus(Status.INTERNAL, message, e.getStatus.getCause)) + case e: StatusRuntimeException if e.getStatus.getCode == Status.Code.UNKNOWN => + Left(richStatus(Status.INTERNAL, message, e.getStatus.getCause)) + case e: StatusException => + Left(richStatus(e.getStatus, message, e.getStatus.getCause)) + case e: StatusRuntimeException => + Left(richStatus(e.getStatus, message, e.getStatus.getCause)) + case _ => + Left(richStatus(Status.INTERNAL, message, ex)) + } + } + } - override def invoke(methodName: GrpcMethodName, body: String, headers: Map[String, String]): F[Either[Status, String]] = - handlersPerMethod.get(methodName.fullName) match { - case None => F.pure(Left(Status.NOT_FOUND.withDescription(s"Method '$methodName' not found"))) - case Some(handler) => - handler(body, headers) - .recover { - case NonFatal(ex) => - val message = "Error while executing the request" - logger.info(message, ex) - ex match { - case e: StatusException if e.getStatus.getCode == Status.Code.UNKNOWN => - Left(richStatus(Status.INTERNAL, message, e.getStatus.getCause)) - case e: StatusRuntimeException if e.getStatus.getCode == Status.Code.UNKNOWN => - Left(richStatus(Status.INTERNAL, message, e.getStatus.getCause)) - case e: StatusException => - Left(richStatus(e.getStatus, message, e.getStatus.getCause)) - case e: StatusRuntimeException => - Left(richStatus(e.getStatus, message, e.getStatus.getCause)) - case _ => - Left(richStatus(Status.INTERNAL, message, ex)) - } - } + override val methodsNames: Seq[GrpcJsonBridge.GrpcMethodName] = handlersPerMethod.keys.map(m => GrpcMethodName(m)).toSeq + override val servicesNames: Seq[String] = methodsNames.map(_.service).distinct } + } - override val methodsNames: Seq[GrpcMethodName] = handlersPerMethod.keys.map(m => GrpcMethodName(m)).toSeq - override val servicesNames: Seq[String] = methodsNames.map(_.service).distinct - -} - -object ReflectionGrpcJsonBridge extends StrictLogging { - - // JSON body and headers to a response (fail status or JSON response) - private type HandlerFunc[F[_]] = (String, Map[String, String]) => F[Either[Status, String]] - - private val parser: JsonFormat.Parser = JsonFormat.parser() + private def createInProcessServer[F[_]](ec: ExecutionContext)(inProcessServiceName: String, services: Seq[ServerServiceDefinition])( + implicit F: Sync[F]): Resource[F, Server] = + Resource.make { + F.delay { + val b = InProcessServerBuilder.forName(inProcessServiceName).executor(ec.execute(_)) + services.foreach(b.addService) + b.build().start() + } + } { s => + F.delay { s.shutdown().awaitTermination() } + } - private val printer: JsonFormat.Printer = { - JsonFormat.printer().includingDefaultValueFields().omittingInsignificantWhitespace() - } + private def createInProcessChannel[F[_]](ec: ExecutionContext)(inProcessServiceName: String)( + implicit F: Sync[F]): Resource[F, ManagedChannel] = + Resource.make { + F.delay { + InProcessChannelBuilder.forName(inProcessServiceName).executor(ec.execute(_)).build() + } + } { c => + F.delay { c.shutdown() } + } private def createFutureStubCtor(sd: ServiceDescriptor, inProcessChannel: Channel): () => AbstractStub[_] = { val serviceGeneratedClass = Class.forName { @@ -100,21 +112,21 @@ object ReflectionGrpcJsonBridge extends StrictLogging { method.invoke(null, inProcessChannel).asInstanceOf[AbstractStub[_]] } - private def createServiceHandlers[F[_]](inProcessChannel: ManagedChannel)( - ssd: ServerServiceDefinition)(implicit ec: ExecutionContext, F: Async[F]): Map[String, HandlerFunc[F]] = { + private def createServiceHandlers[F[_]](ec: ExecutionContext)(inProcessChannel: ManagedChannel)(ssd: ServerServiceDefinition)( + implicit F: Async[F]): Map[String, HandlerFunc[F]] = { val futureStubCtor = createFutureStubCtor(ssd.getServiceDescriptor, inProcessChannel) ssd.getMethods.asScala .filter(isSupportedMethod) - .map(createHandler(futureStubCtor)(_)) + .map(createHandler(ec)(futureStubCtor)(_)) .toMap } - private def createHandler[F[_]](futureStubCtor: () => AbstractStub[_])( - method: ServerMethodDefinition[_, _])(implicit ec: ExecutionContext, F: Async[F]): (String, HandlerFunc[F]) = { + private def createHandler[F[_]](ec: ExecutionContext)(futureStubCtor: () => AbstractStub[_])(method: ServerMethodDefinition[_, _])( + implicit F: Async[F]): (String, HandlerFunc[F]) = { val requestMessagePrototype = getRequestMessagePrototype(method) val javaMethod = futureStubCtor().getClass .getDeclaredMethod(getJavaMethodName(method), requestMessagePrototype.getClass) - val execute = executeRequest[F](futureStubCtor, javaMethod) _ + val execute = executeRequest[F](ec)(futureStubCtor, javaMethod) _ val handler: HandlerFunc[F] = (json, headers) => { parseRequest(json, requestMessagePrototype) match { @@ -125,23 +137,23 @@ object ReflectionGrpcJsonBridge extends StrictLogging { (method.getMethodDescriptor.getFullMethodName, handler) } - private def executeRequest[F[_]](futureStubCtor: () => AbstractStub[_], method: Method)(req: Message, headers: Map[String, String])( - implicit ec: ExecutionContext, - F: Async[F]): F[MessageOrBuilder] = { + private def executeRequest[F[_]](ec: ExecutionContext)(futureStubCtor: () => AbstractStub[_], method: Method)( + req: Message, + headers: Map[String, String])(implicit F: Async[F]): F[MessageOrBuilder] = { val metaData = { val md = new Metadata() headers.foreach { case (k, v) => md.put(Metadata.Key.of(k, Metadata.ASCII_STRING_MARSHALLER), v) } md } val stubWithHeaders = JavaGenericHelper.attachHeaders(futureStubCtor(), metaData) - fromListenableFuture(F.delay { + fromListenableFuture(ec)(F.delay { method.invoke(stubWithHeaders, req).asInstanceOf[ListenableFuture[MessageOrBuilder]] }) } private def isSupportedMethod(d: ServerMethodDefinition[_, _]): Boolean = d.getMethodDescriptor.getType == MethodType.UNARY - private def fromListenableFuture[F[_], A](flf: F[ListenableFuture[A]])(implicit ec: ExecutionContext, F: Async[F]): F[A] = flf.flatMap { + private def fromListenableFuture[F[_], A](ec: ExecutionContext)(flf: F[ListenableFuture[A]])(implicit F: Async[F]): F[A] = flf.flatMap { lf => F.async { cb => Futures.addCallback(lf, new FutureCallback[A] { @@ -183,5 +195,4 @@ object ReflectionGrpcJsonBridge extends StrictLogging { Left(richStatus(Status.INVALID_ARGUMENT, message, ex)) } } - } diff --git a/core/src/test/scala/com/avast/grpc/jsonbridge/ReflectionGrpcJsonBridgeTest.scala b/core/src/test/scala/com/avast/grpc/jsonbridge/ReflectionGrpcJsonBridgeTest.scala index c4038feb..605fe525 100644 --- a/core/src/test/scala/com/avast/grpc/jsonbridge/ReflectionGrpcJsonBridgeTest.scala +++ b/core/src/test/scala/com/avast/grpc/jsonbridge/ReflectionGrpcJsonBridgeTest.scala @@ -15,12 +15,12 @@ class ReflectionGrpcJsonBridgeTest extends fixture.FlatSpec with Matchers { override protected def withFixture(test: OneArgTest): Outcome = { val channelName = InProcessServerBuilder.generateName val server = InProcessServerBuilder.forName(channelName).addService(new TestServiceImpl()).build - val bridge = new ReflectionGrpcJsonBridge[IO](server) + val (bridge, close) = ReflectionGrpcJsonBridge.createFromServer[IO](global)(server).allocated.unsafeRunSync() try { test(FixtureParam(bridge)) } finally { server.shutdownNow() - bridge.close() + close.unsafeRunSync() } } diff --git a/http4s/src/test/scala/com/avast/grpc/jsonbrige/http4s/Http4sTest.scala b/http4s/src/test/scala/com/avast/grpc/jsonbrige/http4s/Http4sTest.scala index 7ec61ac9..31e2be3c 100644 --- a/http4s/src/test/scala/com/avast/grpc/jsonbrige/http4s/Http4sTest.scala +++ b/http4s/src/test/scala/com/avast/grpc/jsonbrige/http4s/Http4sTest.scala @@ -3,18 +3,28 @@ package com.avast.grpc.jsonbrige.http4s import cats.data.NonEmptyList import cats.effect.IO import com.avast.grpc.jsonbridge._ +import io.grpc.ServerServiceDefinition import org.http4s.headers.{`Content-Length`, `Content-Type`} import org.http4s.{Charset, Header, Headers, MediaType, Method, Request, Uri} import org.scalatest.FunSuite import org.scalatest.concurrent.ScalaFutures +import scala.concurrent.ExecutionContext import scala.concurrent.ExecutionContext.Implicits.global import scala.util.Random class Http4sTest extends FunSuite with ScalaFutures { + val ec: ExecutionContext = implicitly[ExecutionContext] + def bridge(ssd: ServerServiceDefinition): GrpcJsonBridge[IO] = + ReflectionGrpcJsonBridge + .createFromServices[IO](ec)(ssd) + .allocated + .unsafeRunSync() + ._1 + test("basic") { - val service = Http4s(Configuration.Default)(new ReflectionGrpcJsonBridge[IO](TestServiceImpl.bindService())) + val service = Http4s(Configuration.Default)(bridge(TestServiceImpl.bindService())) val Some(response) = service .apply( @@ -40,7 +50,7 @@ class Http4sTest extends FunSuite with ScalaFutures { test("path prefix") { val configuration = Configuration.Default.copy(pathPrefix = Some(NonEmptyList.of("abc", "def"))) - val service = Http4s(configuration)(new ReflectionGrpcJsonBridge[IO](TestServiceImpl.bindService())) + val service = Http4s(configuration)(bridge(TestServiceImpl.bindService())) val Some(response) = service .apply( Request[IO](method = Method.POST, uri = Uri.fromString("/abc/def/com.avast.grpc.jsonbridge.test.TestService/Add").getOrElse(fail())) @@ -62,7 +72,7 @@ class Http4sTest extends FunSuite with ScalaFutures { } test("bad request after wrong request") { - val service = Http4s(Configuration.Default)(new ReflectionGrpcJsonBridge[IO](TestServiceImpl.bindService())) + val service = Http4s(Configuration.Default)(bridge(TestServiceImpl.bindService())) { // empty body val Some(response) = service @@ -92,7 +102,7 @@ class Http4sTest extends FunSuite with ScalaFutures { } test("propagate user-specified status") { - val service = Http4s(Configuration.Default)(new ReflectionGrpcJsonBridge[IO](PermissionDeniedTestServiceImpl.bindService())) + val service = Http4s(Configuration.Default)(bridge(PermissionDeniedTestServiceImpl.bindService())) val Some(response) = service .apply( @@ -108,7 +118,7 @@ class Http4sTest extends FunSuite with ScalaFutures { } test("provides service info") { - val service = Http4s(Configuration.Default)(new ReflectionGrpcJsonBridge[IO](TestServiceImpl.bindService())) + val service = Http4s(Configuration.Default)(bridge(TestServiceImpl.bindService())) val Some(response) = service .apply( @@ -123,7 +133,7 @@ class Http4sTest extends FunSuite with ScalaFutures { } test("provides services info") { - val service = Http4s(Configuration.Default)(new ReflectionGrpcJsonBridge[IO](TestServiceImpl.bindService())) + val service = Http4s(Configuration.Default)(bridge(TestServiceImpl.bindService())) val Some(response) = service .apply( @@ -138,7 +148,7 @@ class Http4sTest extends FunSuite with ScalaFutures { } test("passes user headers") { - val service = Http4s(Configuration.Default)(new ReflectionGrpcJsonBridge[IO](TestServiceImpl.withInterceptor)) + val service = Http4s(Configuration.Default)(bridge(TestServiceImpl.withInterceptor)) val headerValue = Random.alphanumeric.take(10).mkString("")