Skip to content

Commit

Permalink
fix: remove thread local based wiring (#1817)
Browse files Browse the repository at this point in the history
* fix: remove thread local based wiring

* added missing test for ill-defined Workflow

* throwing exceptions for real

* fix contructor validation
  • Loading branch information
octonato authored Oct 26, 2023
1 parent 0d7c344 commit 102f95c
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 158 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,18 @@

package kalix.spring.impl

import java.lang.reflect.Constructor
import java.lang.reflect.Modifier
import java.lang.reflect.ParameterizedType

import scala.concurrent.Future
import scala.jdk.CollectionConverters.CollectionHasAsScala
import scala.jdk.FutureConverters.CompletionStageOps
import scala.jdk.OptionConverters.RichOption

import akka.Done
import com.typesafe.config.Config
import kalix.javasdk.Context
import kalix.javasdk.Kalix
import kalix.javasdk.action.Action
import kalix.javasdk.action.ActionCreationContext
Expand Down Expand Up @@ -55,7 +65,6 @@ import kalix.spring.WebClientProvider
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import org.springframework.beans.factory.BeanCreationException
import org.springframework.beans.factory.FactoryBean
import org.springframework.beans.factory.config.BeanDefinition
import org.springframework.beans.factory.config.ConfigurableBeanFactory
import org.springframework.beans.factory.support.AbstractBeanDefinition
Expand All @@ -67,15 +76,6 @@ import org.springframework.core.`type`.classreading.MetadataReader
import org.springframework.core.`type`.classreading.MetadataReaderFactory
import org.springframework.core.`type`.filter.TypeFilter

import java.lang.reflect.Modifier
import java.lang.reflect.ParameterizedType
import scala.concurrent.Future
import scala.jdk.CollectionConverters.CollectionHasAsScala
import scala.jdk.FutureConverters.CompletionStageOps
import scala.jdk.OptionConverters.RichOption
import scala.jdk.OptionConverters._
import scala.reflect.ClassTag

object KalixSpringApplication {

val kalixComponents: Seq[Class[_]] =
Expand Down Expand Up @@ -115,7 +115,7 @@ object KalixSpringApplication {
* The enhanced variant doesn't contain all the annotations, but only the SpringBootApplication one. Therefore, we
* need to lookup for the original one. We need it to find the default ACL annotation.
*/
class MainClassProvider(cglibMain: Class[_]) extends ClassPathScanningCandidateComponentProvider {
private class MainClassProvider(cglibMain: Class[_]) extends ClassPathScanningCandidateComponentProvider {

private object OriginalMainClassFilter extends TypeFilter {
override def `match`(metadataReader: MetadataReader, metadataReaderFactory: MetadataReaderFactory): Boolean = {
Expand Down Expand Up @@ -149,7 +149,7 @@ object KalixSpringApplication {
* This class will do exactly this. It find them and return tweaked BeanDefinitions (eg :prototype scope and autowired
* by constructor)
*/
class KalixComponentProvider(cglibMain: Class[_]) extends ClassPathScanningCandidateComponentProvider {
private class KalixComponentProvider(cglibMain: Class[_]) extends ClassPathScanningCandidateComponentProvider {

private object KalixComponentTypeFilter extends TypeFilter {
override def `match`(metadataReader: MetadataReader, metadataReaderFactory: MetadataReaderFactory): Boolean = {
Expand Down Expand Up @@ -186,66 +186,6 @@ object KalixSpringApplication {
}
}

abstract class ThreadLocalFactoryBean[T: ClassTag] extends FactoryBean[T] {
val threadLocal = new ThreadLocal[T]

def set(value: T) = threadLocal.set(value)

override def getObject: T = threadLocal.get()

override def getObjectType: Class[_] = implicitly[ClassTag[T]].runtimeClass
}

object ActionCreationContextFactoryBean extends ThreadLocalFactoryBean[ActionCreationContext] {
// ActionCreationContext is a singleton, so strictly speaking this could return 'true'
// However, we still need the ThreadLocal hack to let Spring have access to it.
// Also, we don't want to give direct access to it because we want to provide different ActionCreationContext impl
// depending if it's used in prod code or during tests.
override def isSingleton: Boolean = false
}

object EventSourcedEntityContextFactoryBean extends ThreadLocalFactoryBean[EventSourcedEntityContext] {
override def isSingleton: Boolean = false // never!!
}

object WorkflowContextFactoryBean extends ThreadLocalFactoryBean[WorkflowContext] {
override def isSingleton: Boolean = false // never!!
}

object ValueEntityContextFactoryBean extends ThreadLocalFactoryBean[ValueEntityContext] {
override def isSingleton: Boolean = false // never!!
}

object ViewCreationContextFactoryBean extends ThreadLocalFactoryBean[ViewCreationContext] {
override def isSingleton: Boolean = false // never!!
}

object KalixClientFactoryBean extends ThreadLocalFactoryBean[KalixClient] {
override def isSingleton: Boolean = true // yes, we only need one

override def getObject: KalixClient =
if (threadLocal.get() != null) threadLocal.get()
else
throw new BeanCreationException("KalixClient can only be injected in Kalix Actions and Workflows.")
}

object ComponentClientFactoryBean extends ThreadLocalFactoryBean[ComponentClient] {
override def isSingleton: Boolean = true // yes, we only need one

override def getObject: ComponentClient =
if (threadLocal.get() != null) threadLocal.get()
else
throw new BeanCreationException("ComponentClient can only be injected in Kalix Actions and Workflows.")
}

object WebClientProviderFactoryBean extends ThreadLocalFactoryBean[WebClientProvider] {
override def isSingleton: Boolean = true // yes, we only need one

override def getObject: WebClientProvider =
if (threadLocal.get() != null) threadLocal.get()
else
throw new BeanCreationException("WebClientProvider can only be injected in Kalix Actions and Workflows.")
}
}

case class KalixSpringApplication(applicationContext: ApplicationContext, config: Config) {
Expand All @@ -260,15 +200,6 @@ case class KalixSpringApplication(applicationContext: ApplicationContext, config

private val kalixBeanFactory = new DefaultListableBeanFactory(applicationContext)

kalixBeanFactory.registerSingleton("actionCreationContextFactoryBean", ActionCreationContextFactoryBean)
kalixBeanFactory.registerSingleton("eventSourcedEntityContext", EventSourcedEntityContextFactoryBean)
kalixBeanFactory.registerSingleton("workflowEntityContext", WorkflowContextFactoryBean)
kalixBeanFactory.registerSingleton("valueEntityContext", ValueEntityContextFactoryBean)
kalixBeanFactory.registerSingleton("viewCreationContext", ViewCreationContextFactoryBean)
kalixBeanFactory.registerSingleton("kalixClient", KalixClientFactoryBean)
kalixBeanFactory.registerSingleton("componentClient", ComponentClientFactoryBean)
kalixBeanFactory.registerSingleton("webClientProvider", WebClientProviderFactoryBean)

// there should be only one class annotated with SpringBootApplication in the applicationContext
private val cglibEnhanceMainClass =
applicationContext.getBeansWithAnnotation(classOf[SpringBootApplication]).values().asScala.head
Expand All @@ -284,7 +215,7 @@ case class KalixSpringApplication(applicationContext: ApplicationContext, config
provider.setEnvironment(applicationContext.getEnvironment) //use the same environment to get access to properties

// load all Kalix components found in the classpath
val classBeanMap =
private val classBeanMap =
provider.findKalixComponents.map { bean =>
// here we need to load the components using the same loader as the Main class
// this is needed to have it loaded in the RestartClassLoader when using auto-reload
Expand All @@ -293,7 +224,7 @@ case class KalixSpringApplication(applicationContext: ApplicationContext, config
}.toMap

// each loaded class needs to be validated before registration
val validation =
private val validation =
classBeanMap.keySet
.foldLeft(Valid: Validation) { case (validations, cls) =>
validations ++ Validations.validate(cls)
Expand Down Expand Up @@ -368,88 +299,101 @@ case class KalixSpringApplication(applicationContext: ApplicationContext, config

def port: Int = kalixRunner.configuration.userFunctionPort

/* Each component may have a creation context passed to its constructor.
* This method checks if there is a constructor in `clz` that receives a `context`.
*/
private def hasContextConstructor(clz: Class[_], contextType: Class[_]): Boolean =
clz.getConstructors.exists { ctor =>
ctor.getParameterTypes.contains(contextType)
}

private def actionProvider[A <: Action](clz: Class[A]): ActionProvider[A] =
ReflectiveActionProvider.of(
clz,
messageCodec,
context => {
if (hasContextConstructor(clz, classOf[ActionCreationContext]))
ActionCreationContextFactoryBean.set(context)

val webClientProviderHolder = WebClientProviderHolder(context.materializer().system)

setKalixClient(clz, webClientProviderHolder)
setComponentClient(clz, webClientProviderHolder)
private def webClientProvider(context: Context) = {
val webClientProviderHolder = WebClientProviderHolder(context.materializer().system)
webClientProviderHolder.webClientProvider
}

if (hasContextConstructor(clz, classOf[WebClientProvider])) {
val webClientProvider = webClientProviderHolder.webClientProvider
WebClientProviderFactoryBean.set(webClientProvider)
}
/**
* Create an instance of `clz` using the mappings defined in `partial`. Each component provider should define what are
* the acceptable dependencies in the partial function.
*
* If the partial function doesn't match, it will try to lookup in the Spring applicationContext.
*/
private def wiredInstance[T](clz: Class[T])(partial: PartialFunction[Class[_], Any]): T = {
// only one constructor allowed
require(clz.getDeclaredConstructors.length == 1, s"Class [${clz.getSimpleName}] must have only one constructor.")
wiredInstance(clz.getDeclaredConstructors.head.asInstanceOf[Constructor[T]])(partial)
}

kalixBeanFactory.getBean(clz)
})
/**
* Create an instance using the passed `constructor` and the mappings defined in `partial`.
*
* Each component provider should define what are the acceptable dependencies in the partial function.
*
* If the partial function doesn't match, it will try to lookup in the Spring applicationContext.
*/
private def wiredInstance[T](constructor: Constructor[T])(partial: PartialFunction[Class[_], Any]): T = {

// Note that this function is total because it will always return a value (even if null)
// last case is a catch all that lookups in the applicationContext
val totalWireFunction: PartialFunction[Class[_], Any] =
partial.orElse {
// block wiring of clients into anything that is not an Action or Workflow
// NOTE: if they are allowed, 'partial' should already have a matching case for them
case p if p == classOf[KalixClient] =>
throw new BeanCreationException(
s"[${constructor.getDeclaringClass.getSimpleName}] are not allowed to have a dependency on KalixClient")

case p if p == classOf[ComponentClient] =>
throw new BeanCreationException(
s"[${constructor.getDeclaringClass.getSimpleName}] are not allowed to have a dependency on ComponentClient")

case p if p == classOf[WebClientProvider] =>
throw new BeanCreationException(
s"[${constructor.getDeclaringClass.getSimpleName}] are not allowed to have a dependency on WebClientProvider")

// if partial func doesn't match, try to lookup in the applicationContext
case anyOther =>
val bean = applicationContext.getBean(anyOther)
if (bean == null)
throw new BeanCreationException(
s"Cannot wire [${anyOther.getSimpleName}]. Bean not found in the Application Context");
else bean
}

private def setKalixClient[T](clz: Class[T], webClientProviderHolder: WebClientProviderHolder): Unit = {
if (hasContextConstructor(clz, classOf[KalixClient])) {
kalixClient.setWebClient(webClientProviderHolder.webClientProvider.localWebClient)
// we only have one KalixClient, but we only set it to the ThreadLocalFactoryBean
// when building actions, because it's only allowed to inject it in Actions and Workflow Entities
KalixClientFactoryBean.set(kalixClient)
}
}
// all params must be wired so we use 'map' not 'collect'
val params = constructor.getParameterTypes.map(totalWireFunction)

private def setComponentClient[T](clz: Class[T], webClientProviderHolder: WebClientProviderHolder): Unit = {
if (hasContextConstructor(clz, classOf[ComponentClient])) {
kalixClient.setWebClient(webClientProviderHolder.webClientProvider.localWebClient)
ComponentClientFactoryBean.set(componentClient)
}
constructor.newInstance(params: _*)
}

private def eventSourcedEntityProvider[S, E, ES <: EventSourcedEntity[S, E]](
clz: Class[ES]): EventSourcedEntityProvider[S, E, ES] =
ReflectiveEventSourcedEntityProvider.of(
private def actionProvider[A <: Action](clz: Class[A]): ActionProvider[A] =
ReflectiveActionProvider.of(
clz,
messageCodec,
context => {
if (hasContextConstructor(clz, classOf[EventSourcedEntityContext]))
EventSourcedEntityContextFactoryBean.set(context)
kalixBeanFactory.getBean(clz)
})

private def workflowProvider[S, E <: Workflow[S]](clz: Class[E]): WorkflowProvider[S, E] = {
context =>
wiredInstance(clz) {
case p if p == classOf[ActionCreationContext] => context
case p if p == classOf[KalixClient] => kalixClient
case p if p == classOf[ComponentClient] => componentClient
case p if p == classOf[WebClientProvider] => webClientProvider(context)
})

private def workflowProvider[S, W <: Workflow[S]](clz: Class[W]): WorkflowProvider[S, W] = {
ReflectiveWorkflowProvider.of(
clz,
messageCodec,
context => {
if (hasContextConstructor(clz, classOf[WorkflowContext])) {
WorkflowContextFactoryBean.set(context)
}

val webClientProviderHolder = WebClientProviderHolder(context.materializer().system)

setKalixClient(clz, webClientProviderHolder)
setComponentClient(clz, webClientProviderHolder)

val workflowEntity = kalixBeanFactory.getBean(clz)
val workflow =
wiredInstance(clz) {
case p if p == classOf[WorkflowContext] => context
case p if p == classOf[KalixClient] => kalixClient
case p if p == classOf[ComponentClient] => componentClient
case p if p == classOf[WebClientProvider] => webClientProvider(context)
}

val workflowStateType: Class[S] =
workflowEntity.getClass.getGenericSuperclass
workflow.getClass.getGenericSuperclass
.asInstanceOf[ParameterizedType]
.getActualTypeArguments
.head
.asInstanceOf[Class[S]]

messageCodec.registerTypeHints(workflowStateType)

workflowEntity
workflow
.definition()
.getSteps
.asScala
Expand All @@ -461,37 +405,35 @@ case class KalixSpringApplication(applicationContext: ApplicationContext, config
}
.foreach(messageCodec.registerTypeHints)

workflowEntity
workflow
})
}

private def valueEntityProvider[S, E <: ValueEntity[S]](clz: Class[E]): ValueEntityProvider[S, E] =
private def eventSourcedEntityProvider[S, E, ES <: EventSourcedEntity[S, E]](
clz: Class[ES]): EventSourcedEntityProvider[S, E, ES] =
ReflectiveEventSourcedEntityProvider.of(
clz,
messageCodec,
context => wiredInstance(clz) { case p if p == classOf[EventSourcedEntityContext] => context })

private def valueEntityProvider[S, VE <: ValueEntity[S]](clz: Class[VE]): ValueEntityProvider[S, VE] =
ReflectiveValueEntityProvider.of(
clz,
messageCodec,
context => {
if (hasContextConstructor(clz, classOf[ValueEntityContext]))
ValueEntityContextFactoryBean.set(context)
kalixBeanFactory.getBean(clz)
})
context => wiredInstance(clz) { case p if p == classOf[ValueEntityContext] => context })

private def viewProvider[S, V <: View[S]](clz: Class[V]): ViewProvider =
ReflectiveViewProvider.of[S, V](
clz,
messageCodec,
context => {
if (hasContextConstructor(clz, classOf[ViewCreationContext]))
ViewCreationContextFactoryBean.set(context)
kalixBeanFactory.getBean(clz)
})
context => wiredInstance(clz) { case p if p == classOf[ViewCreationContext] => context })

private def multiTableViewProvider[V](clz: Class[V]): ViewProvider =
ReflectiveMultiTableViewProvider.of[V](
clz,
messageCodec,
(viewTableClass, context) => {
if (hasContextConstructor(viewTableClass, classOf[ViewCreationContext]))
ViewCreationContextFactoryBean.set(context)
kalixBeanFactory.getBean(viewTableClass)
val constructor = viewTableClass.getConstructors.head.asInstanceOf[Constructor[View[_]]]
wiredInstance(constructor) { case p if p == classOf[ViewCreationContext] => context }
})
}
Loading

0 comments on commit 102f95c

Please sign in to comment.