Skip to content

Commit

Permalink
Merge pull request #36 from sideeffffect/resource
Browse files Browse the repository at this point in the history
create `GrpcJsonBridge` as a `Resource`
  • Loading branch information
sideeffffect authored May 27, 2019
2 parents 946dddc + 4325fa0 commit 0b03654
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 93 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ import com.avast.grpc.jsonbridge.ReflectionGrpcJsonBridge

// for whole server
val grpcServer: io.grpc.Server = ???
val bridge = new ReflectionGrpcJsonBridge[Task](grpcServer)
val bridge = ReflectionGrpcJsonBridge.createFromServer[Task](grpcServer)

// or for selected services
val s1: ServerServiceDefinition = ???
val s2: ServerServiceDefinition = ???
val anotherBridge = new ReflectionGrpcJsonBridge[Task](s1, s2)
val anotherBridge = ReflectionGrpcJsonBridge.createFromServices[Task](s1, s2)

// call a method manually, with a header specified
val jsonResponse = bridge.invoke("com.avast.grpc.jsonbridge.test.TestService/Add", """ { "a": 1, "b": 2} """, Map("My-Header" -> "value"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,24 @@ import akka.http.scaladsl.testkit.ScalatestRouteTest
import cats.data.NonEmptyList
import cats.effect.IO
import com.avast.grpc.jsonbridge._
import io.grpc.ServerServiceDefinition
import org.scalatest.FunSuite

import scala.concurrent.ExecutionContext
import scala.util.Random

class AkkaHttpTest extends FunSuite with ScalatestRouteTest {

val ec: ExecutionContext = implicitly[ExecutionContext]
def bridge(ssd: ServerServiceDefinition): GrpcJsonBridge[IO] =
ReflectionGrpcJsonBridge
.createFromServices[IO](ec)(ssd)
.allocated
.unsafeRunSync()
._1

test("basic") {
val route = AkkaHttp[IO](Configuration.Default)(new ReflectionGrpcJsonBridge[IO](TestServiceImpl.bindService()))
val route = AkkaHttp[IO](Configuration.Default)(bridge(TestServiceImpl.bindService()))
Post("/com.avast.grpc.jsonbridge.test.TestService/Add", """ { "a": 1, "b": 2} """)
.withHeaders(AkkaHttp.JsonContentType) ~> route ~> check {
assertResult(StatusCodes.OK)(status)
Expand All @@ -25,7 +35,7 @@ class AkkaHttpTest extends FunSuite with ScalatestRouteTest {

test("with path prefix") {
val configuration = Configuration.Default.copy(pathPrefix = Some(NonEmptyList.of("abc", "def")))
val route = AkkaHttp[IO](configuration)(new ReflectionGrpcJsonBridge[IO](TestServiceImpl.bindService()))
val route = AkkaHttp[IO](configuration)(bridge(TestServiceImpl.bindService()))
Post("/abc/def/com.avast.grpc.jsonbridge.test.TestService/Add", """ { "a": 1, "b": 2} """)
.withHeaders(AkkaHttp.JsonContentType) ~> route ~> check {
assertResult(StatusCodes.OK)(status)
Expand All @@ -34,7 +44,7 @@ class AkkaHttpTest extends FunSuite with ScalatestRouteTest {
}

test("bad request after wrong request") {
val route = AkkaHttp[IO](Configuration.Default)(new ReflectionGrpcJsonBridge[IO](TestServiceImpl.bindService()))
val route = AkkaHttp[IO](Configuration.Default)(bridge(TestServiceImpl.bindService()))
// empty body
Post("/com.avast.grpc.jsonbridge.test.TestService/Add", "")
.withHeaders(AkkaHttp.JsonContentType) ~> route ~> check {
Expand All @@ -47,23 +57,23 @@ class AkkaHttpTest extends FunSuite with ScalatestRouteTest {
}

test("propagates user-specified status") {
val route = AkkaHttp(Configuration.Default)(new ReflectionGrpcJsonBridge[IO](PermissionDeniedTestServiceImpl.bindService()))
val route = AkkaHttp(Configuration.Default)(bridge(PermissionDeniedTestServiceImpl.bindService()))
Post(s"/com.avast.grpc.jsonbridge.test.TestService/Add", """ { "a": 1, "b": 2} """)
.withHeaders(AkkaHttp.JsonContentType) ~> route ~> check {
assertResult(status)(StatusCodes.Forbidden)
}
}

test("provides service description") {
val route = AkkaHttp[IO](Configuration.Default)(new ReflectionGrpcJsonBridge[IO](TestServiceImpl.bindService()))
val route = AkkaHttp[IO](Configuration.Default)(bridge(TestServiceImpl.bindService()))
Get("/com.avast.grpc.jsonbridge.test.TestService") ~> route ~> check {
assertResult(StatusCodes.OK)(status)
assertResult("com.avast.grpc.jsonbridge.test.TestService/Add")(responseAs[String])
}
}

test("provides services description") {
val route = AkkaHttp[IO](Configuration.Default)(new ReflectionGrpcJsonBridge[IO](TestServiceImpl.bindService()))
val route = AkkaHttp[IO](Configuration.Default)(bridge(TestServiceImpl.bindService()))
Get("/") ~> route ~> check {
assertResult(StatusCodes.OK)(status)
assertResult("com.avast.grpc.jsonbridge.test.TestService/Add")(responseAs[String])
Expand All @@ -72,7 +82,7 @@ class AkkaHttpTest extends FunSuite with ScalatestRouteTest {

test("passes headers") {
val headerValue = Random.alphanumeric.take(10).mkString("")
val route = AkkaHttp[IO](Configuration.Default)(new ReflectionGrpcJsonBridge[IO](TestServiceImpl.withInterceptor))
val route = AkkaHttp[IO](Configuration.Default)(bridge(TestServiceImpl.withInterceptor))
val Ok(customHeaderToBeSent, _) = HttpHeader.parse(TestServiceImpl.HeaderName, headerValue)
Post("/com.avast.grpc.jsonbridge.test.TestService/Add", """ { "a": 1, "b": 2} """)
.withHeaders(AkkaHttp.JsonContentType, customHeaderToBeSent) ~> route ~> check {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,93 +2,105 @@ package com.avast.grpc.jsonbridge

import java.lang.reflect.Method

import cats.effect.Async
import cats.syntax.all._
import cats.effect._
import cats.implicits._
import com.avast.grpc.jsonbridge.GrpcJsonBridge.GrpcMethodName
import com.google.common.util.concurrent._
import com.google.common.util.concurrent.{FutureCallback, Futures, ListenableFuture}
import com.google.protobuf.util.JsonFormat
import com.google.protobuf.{Message, MessageOrBuilder}
import com.typesafe.scalalogging.StrictLogging
import io.grpc.MethodDescriptor.{MethodType, PrototypeMarshaller}
import io.grpc._
import io.grpc.inprocess.{InProcessChannelBuilder, InProcessServerBuilder}
import io.grpc._
import io.grpc.stub.AbstractStub

import scala.collection.JavaConverters._
import scala.concurrent.ExecutionContext
import scala.language.{existentials, higherKinds}
import scala.language.existentials
import scala.language.higherKinds
import scala.util.control.NonFatal

class ReflectionGrpcJsonBridge[F[_]](services: ServerServiceDefinition*)(implicit ec: ExecutionContext, F: Async[F])
extends GrpcJsonBridge[F]
with AutoCloseable
with StrictLogging {

import com.avast.grpc.jsonbridge.ReflectionGrpcJsonBridge._
object ReflectionGrpcJsonBridge extends StrictLogging {

def this(grpcServer: io.grpc.Server)(implicit ec: ExecutionContext, F: Async[F]) = this(grpcServer.getImmutableServices.asScala: _*)
// JSON body and headers to a response (fail status or JSON response)
type HandlerFunc[F[_]] = (String, Map[String, String]) => F[Either[Status, String]]

private val inProcessServiceName = s"HttpWrapper-${System.nanoTime()}"
private val parser: JsonFormat.Parser = JsonFormat.parser()

private val inProcessServer = {
val b = InProcessServerBuilder.forName(inProcessServiceName).executor(ec.execute(_))
services.foreach(b.addService)
b.build().start()
private val printer: JsonFormat.Printer = {
JsonFormat.printer().includingDefaultValueFields().omittingInsignificantWhitespace()
}

private val inProcessChannel = InProcessChannelBuilder.forName(inProcessServiceName).executor(ec.execute(_)).build()
def createFromServer[F[_]](ec: ExecutionContext)(grpcServer: io.grpc.Server)(implicit F: Async[F]): Resource[F, GrpcJsonBridge[F]] = {
createFromServices(ec)(grpcServer.getImmutableServices.asScala: _*)
}

override def close(): Unit = {
inProcessChannel.shutdownNow()
inProcessServer.shutdownNow()
()
def createFromServices[F[_]](ec: ExecutionContext)(services: ServerServiceDefinition*)(
implicit F: Async[F]): Resource[F, GrpcJsonBridge[F]] = {
for {
inProcessServiceName <- Resource.liftF(F.delay { s"ReflectionGrpcJsonBridge-${System.nanoTime()}" })
inProcessServer <- createInProcessServer(ec)(inProcessServiceName, services)
inProcessChannel <- createInProcessChannel(ec)(inProcessServiceName)
handlersPerMethod = inProcessServer.getImmutableServices.asScala
.flatMap(createServiceHandlers(ec)(inProcessChannel)(_))
.toMap
bridge = createFromHandlers(handlersPerMethod)
} yield bridge
}

// map from full method name to a function that invokes that method
protected val handlersPerMethod: Map[String, HandlerFunc[F]] =
inProcessServer.getImmutableServices.asScala
.flatMap(createServiceHandlers(inProcessChannel)(_))
.toMap
def createFromHandlers[F[_]](handlersPerMethod: Map[String, HandlerFunc[F]])(implicit F: Async[F]): GrpcJsonBridge[F] = {
new GrpcJsonBridge[F] {
override def invoke(methodName: GrpcJsonBridge.GrpcMethodName,
body: String,
headers: Map[String, String]): F[Either[Status, String]] = handlersPerMethod.get(methodName.fullName) match {
case None => F.pure(Left(Status.NOT_FOUND.withDescription(s"Method '$methodName' not found")))
case Some(handler) =>
handler(body, headers)
.recover {
case NonFatal(ex) =>
val message = "Error while executing the request"
logger.info(message, ex)
ex match {
case e: StatusException if e.getStatus.getCode == Status.Code.UNKNOWN =>
Left(richStatus(Status.INTERNAL, message, e.getStatus.getCause))
case e: StatusRuntimeException if e.getStatus.getCode == Status.Code.UNKNOWN =>
Left(richStatus(Status.INTERNAL, message, e.getStatus.getCause))
case e: StatusException =>
Left(richStatus(e.getStatus, message, e.getStatus.getCause))
case e: StatusRuntimeException =>
Left(richStatus(e.getStatus, message, e.getStatus.getCause))
case _ =>
Left(richStatus(Status.INTERNAL, message, ex))
}
}
}

override def invoke(methodName: GrpcMethodName, body: String, headers: Map[String, String]): F[Either[Status, String]] =
handlersPerMethod.get(methodName.fullName) match {
case None => F.pure(Left(Status.NOT_FOUND.withDescription(s"Method '$methodName' not found")))
case Some(handler) =>
handler(body, headers)
.recover {
case NonFatal(ex) =>
val message = "Error while executing the request"
logger.info(message, ex)
ex match {
case e: StatusException if e.getStatus.getCode == Status.Code.UNKNOWN =>
Left(richStatus(Status.INTERNAL, message, e.getStatus.getCause))
case e: StatusRuntimeException if e.getStatus.getCode == Status.Code.UNKNOWN =>
Left(richStatus(Status.INTERNAL, message, e.getStatus.getCause))
case e: StatusException =>
Left(richStatus(e.getStatus, message, e.getStatus.getCause))
case e: StatusRuntimeException =>
Left(richStatus(e.getStatus, message, e.getStatus.getCause))
case _ =>
Left(richStatus(Status.INTERNAL, message, ex))
}
}
override val methodsNames: Seq[GrpcJsonBridge.GrpcMethodName] = handlersPerMethod.keys.map(m => GrpcMethodName(m)).toSeq
override val servicesNames: Seq[String] = methodsNames.map(_.service).distinct
}
}

override val methodsNames: Seq[GrpcMethodName] = handlersPerMethod.keys.map(m => GrpcMethodName(m)).toSeq
override val servicesNames: Seq[String] = methodsNames.map(_.service).distinct

}

object ReflectionGrpcJsonBridge extends StrictLogging {

// JSON body and headers to a response (fail status or JSON response)
private type HandlerFunc[F[_]] = (String, Map[String, String]) => F[Either[Status, String]]

private val parser: JsonFormat.Parser = JsonFormat.parser()
private def createInProcessServer[F[_]](ec: ExecutionContext)(inProcessServiceName: String, services: Seq[ServerServiceDefinition])(
implicit F: Sync[F]): Resource[F, Server] =
Resource.make {
F.delay {
val b = InProcessServerBuilder.forName(inProcessServiceName).executor(ec.execute(_))
services.foreach(b.addService)
b.build().start()
}
} { s =>
F.delay { s.shutdown().awaitTermination() }
}

private val printer: JsonFormat.Printer = {
JsonFormat.printer().includingDefaultValueFields().omittingInsignificantWhitespace()
}
private def createInProcessChannel[F[_]](ec: ExecutionContext)(inProcessServiceName: String)(
implicit F: Sync[F]): Resource[F, ManagedChannel] =
Resource.make {
F.delay {
InProcessChannelBuilder.forName(inProcessServiceName).executor(ec.execute(_)).build()
}
} { c =>
F.delay { c.shutdown() }
}

private def createFutureStubCtor(sd: ServiceDescriptor, inProcessChannel: Channel): () => AbstractStub[_] = {
val serviceGeneratedClass = Class.forName {
Expand All @@ -100,21 +112,21 @@ object ReflectionGrpcJsonBridge extends StrictLogging {
method.invoke(null, inProcessChannel).asInstanceOf[AbstractStub[_]]
}

private def createServiceHandlers[F[_]](inProcessChannel: ManagedChannel)(
ssd: ServerServiceDefinition)(implicit ec: ExecutionContext, F: Async[F]): Map[String, HandlerFunc[F]] = {
private def createServiceHandlers[F[_]](ec: ExecutionContext)(inProcessChannel: ManagedChannel)(ssd: ServerServiceDefinition)(
implicit F: Async[F]): Map[String, HandlerFunc[F]] = {
val futureStubCtor = createFutureStubCtor(ssd.getServiceDescriptor, inProcessChannel)
ssd.getMethods.asScala
.filter(isSupportedMethod)
.map(createHandler(futureStubCtor)(_))
.map(createHandler(ec)(futureStubCtor)(_))
.toMap
}

private def createHandler[F[_]](futureStubCtor: () => AbstractStub[_])(
method: ServerMethodDefinition[_, _])(implicit ec: ExecutionContext, F: Async[F]): (String, HandlerFunc[F]) = {
private def createHandler[F[_]](ec: ExecutionContext)(futureStubCtor: () => AbstractStub[_])(method: ServerMethodDefinition[_, _])(
implicit F: Async[F]): (String, HandlerFunc[F]) = {
val requestMessagePrototype = getRequestMessagePrototype(method)
val javaMethod = futureStubCtor().getClass
.getDeclaredMethod(getJavaMethodName(method), requestMessagePrototype.getClass)
val execute = executeRequest[F](futureStubCtor, javaMethod) _
val execute = executeRequest[F](ec)(futureStubCtor, javaMethod) _

val handler: HandlerFunc[F] = (json, headers) => {
parseRequest(json, requestMessagePrototype) match {
Expand All @@ -125,23 +137,23 @@ object ReflectionGrpcJsonBridge extends StrictLogging {
(method.getMethodDescriptor.getFullMethodName, handler)
}

private def executeRequest[F[_]](futureStubCtor: () => AbstractStub[_], method: Method)(req: Message, headers: Map[String, String])(
implicit ec: ExecutionContext,
F: Async[F]): F[MessageOrBuilder] = {
private def executeRequest[F[_]](ec: ExecutionContext)(futureStubCtor: () => AbstractStub[_], method: Method)(
req: Message,
headers: Map[String, String])(implicit F: Async[F]): F[MessageOrBuilder] = {
val metaData = {
val md = new Metadata()
headers.foreach { case (k, v) => md.put(Metadata.Key.of(k, Metadata.ASCII_STRING_MARSHALLER), v) }
md
}
val stubWithHeaders = JavaGenericHelper.attachHeaders(futureStubCtor(), metaData)
fromListenableFuture(F.delay {
fromListenableFuture(ec)(F.delay {
method.invoke(stubWithHeaders, req).asInstanceOf[ListenableFuture[MessageOrBuilder]]
})
}

private def isSupportedMethod(d: ServerMethodDefinition[_, _]): Boolean = d.getMethodDescriptor.getType == MethodType.UNARY

private def fromListenableFuture[F[_], A](flf: F[ListenableFuture[A]])(implicit ec: ExecutionContext, F: Async[F]): F[A] = flf.flatMap {
private def fromListenableFuture[F[_], A](ec: ExecutionContext)(flf: F[ListenableFuture[A]])(implicit F: Async[F]): F[A] = flf.flatMap {
lf =>
F.async { cb =>
Futures.addCallback(lf, new FutureCallback[A] {
Expand Down Expand Up @@ -183,5 +195,4 @@ object ReflectionGrpcJsonBridge extends StrictLogging {
Left(richStatus(Status.INVALID_ARGUMENT, message, ex))
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ class ReflectionGrpcJsonBridgeTest extends fixture.FlatSpec with Matchers {
override protected def withFixture(test: OneArgTest): Outcome = {
val channelName = InProcessServerBuilder.generateName
val server = InProcessServerBuilder.forName(channelName).addService(new TestServiceImpl()).build
val bridge = new ReflectionGrpcJsonBridge[IO](server)
val (bridge, close) = ReflectionGrpcJsonBridge.createFromServer[IO](global)(server).allocated.unsafeRunSync()
try {
test(FixtureParam(bridge))
} finally {
server.shutdownNow()
bridge.close()
close.unsafeRunSync()
}
}

Expand Down
Loading

0 comments on commit 0b03654

Please sign in to comment.