Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: remove thread local based wiring #1817

Merged
merged 4 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,18 @@

package kalix.spring.impl

import java.lang.reflect.Constructor
import java.lang.reflect.Modifier
import java.lang.reflect.ParameterizedType

import scala.concurrent.Future
import scala.jdk.CollectionConverters.CollectionHasAsScala
import scala.jdk.FutureConverters.CompletionStageOps
import scala.jdk.OptionConverters.RichOption

import akka.Done
import com.typesafe.config.Config
import kalix.javasdk.Context
import kalix.javasdk.Kalix
import kalix.javasdk.action.Action
import kalix.javasdk.action.ActionCreationContext
Expand Down Expand Up @@ -55,7 +65,6 @@ import kalix.spring.WebClientProvider
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import org.springframework.beans.factory.BeanCreationException
import org.springframework.beans.factory.FactoryBean
import org.springframework.beans.factory.config.BeanDefinition
import org.springframework.beans.factory.config.ConfigurableBeanFactory
import org.springframework.beans.factory.support.AbstractBeanDefinition
Expand All @@ -67,15 +76,6 @@ import org.springframework.core.`type`.classreading.MetadataReader
import org.springframework.core.`type`.classreading.MetadataReaderFactory
import org.springframework.core.`type`.filter.TypeFilter

import java.lang.reflect.Modifier
import java.lang.reflect.ParameterizedType
import scala.concurrent.Future
import scala.jdk.CollectionConverters.CollectionHasAsScala
import scala.jdk.FutureConverters.CompletionStageOps
import scala.jdk.OptionConverters.RichOption
import scala.jdk.OptionConverters._
import scala.reflect.ClassTag

object KalixSpringApplication {

val kalixComponents: Seq[Class[_]] =
Expand Down Expand Up @@ -115,7 +115,7 @@ object KalixSpringApplication {
* The enhanced variant doesn't contain all the annotations, but only the SpringBootApplication one. Therefore, we
* need to lookup for the original one. We need it to find the default ACL annotation.
*/
class MainClassProvider(cglibMain: Class[_]) extends ClassPathScanningCandidateComponentProvider {
private class MainClassProvider(cglibMain: Class[_]) extends ClassPathScanningCandidateComponentProvider {

private object OriginalMainClassFilter extends TypeFilter {
override def `match`(metadataReader: MetadataReader, metadataReaderFactory: MetadataReaderFactory): Boolean = {
Expand Down Expand Up @@ -149,7 +149,7 @@ object KalixSpringApplication {
* This class will do exactly this. It find them and return tweaked BeanDefinitions (eg :prototype scope and autowired
* by constructor)
*/
class KalixComponentProvider(cglibMain: Class[_]) extends ClassPathScanningCandidateComponentProvider {
private class KalixComponentProvider(cglibMain: Class[_]) extends ClassPathScanningCandidateComponentProvider {

private object KalixComponentTypeFilter extends TypeFilter {
override def `match`(metadataReader: MetadataReader, metadataReaderFactory: MetadataReaderFactory): Boolean = {
Expand Down Expand Up @@ -186,66 +186,6 @@ object KalixSpringApplication {
}
}

abstract class ThreadLocalFactoryBean[T: ClassTag] extends FactoryBean[T] {
val threadLocal = new ThreadLocal[T]

def set(value: T) = threadLocal.set(value)

override def getObject: T = threadLocal.get()

override def getObjectType: Class[_] = implicitly[ClassTag[T]].runtimeClass
}

object ActionCreationContextFactoryBean extends ThreadLocalFactoryBean[ActionCreationContext] {
// ActionCreationContext is a singleton, so strictly speaking this could return 'true'
// However, we still need the ThreadLocal hack to let Spring have access to it.
// Also, we don't want to give direct access to it because we want to provide different ActionCreationContext impl
// depending if it's used in prod code or during tests.
override def isSingleton: Boolean = false
}

object EventSourcedEntityContextFactoryBean extends ThreadLocalFactoryBean[EventSourcedEntityContext] {
override def isSingleton: Boolean = false // never!!
}

object WorkflowContextFactoryBean extends ThreadLocalFactoryBean[WorkflowContext] {
override def isSingleton: Boolean = false // never!!
}

object ValueEntityContextFactoryBean extends ThreadLocalFactoryBean[ValueEntityContext] {
override def isSingleton: Boolean = false // never!!
}

object ViewCreationContextFactoryBean extends ThreadLocalFactoryBean[ViewCreationContext] {
override def isSingleton: Boolean = false // never!!
}

object KalixClientFactoryBean extends ThreadLocalFactoryBean[KalixClient] {
override def isSingleton: Boolean = true // yes, we only need one

override def getObject: KalixClient =
if (threadLocal.get() != null) threadLocal.get()
else
throw new BeanCreationException("KalixClient can only be injected in Kalix Actions and Workflows.")
}

object ComponentClientFactoryBean extends ThreadLocalFactoryBean[ComponentClient] {
override def isSingleton: Boolean = true // yes, we only need one

override def getObject: ComponentClient =
if (threadLocal.get() != null) threadLocal.get()
else
throw new BeanCreationException("ComponentClient can only be injected in Kalix Actions and Workflows.")
}

object WebClientProviderFactoryBean extends ThreadLocalFactoryBean[WebClientProvider] {
override def isSingleton: Boolean = true // yes, we only need one

override def getObject: WebClientProvider =
if (threadLocal.get() != null) threadLocal.get()
else
throw new BeanCreationException("WebClientProvider can only be injected in Kalix Actions and Workflows.")
}
}

case class KalixSpringApplication(applicationContext: ApplicationContext, config: Config) {
Expand All @@ -260,15 +200,6 @@ case class KalixSpringApplication(applicationContext: ApplicationContext, config

private val kalixBeanFactory = new DefaultListableBeanFactory(applicationContext)

kalixBeanFactory.registerSingleton("actionCreationContextFactoryBean", ActionCreationContextFactoryBean)
kalixBeanFactory.registerSingleton("eventSourcedEntityContext", EventSourcedEntityContextFactoryBean)
kalixBeanFactory.registerSingleton("workflowEntityContext", WorkflowContextFactoryBean)
kalixBeanFactory.registerSingleton("valueEntityContext", ValueEntityContextFactoryBean)
kalixBeanFactory.registerSingleton("viewCreationContext", ViewCreationContextFactoryBean)
kalixBeanFactory.registerSingleton("kalixClient", KalixClientFactoryBean)
kalixBeanFactory.registerSingleton("componentClient", ComponentClientFactoryBean)
kalixBeanFactory.registerSingleton("webClientProvider", WebClientProviderFactoryBean)

// there should be only one class annotated with SpringBootApplication in the applicationContext
private val cglibEnhanceMainClass =
applicationContext.getBeansWithAnnotation(classOf[SpringBootApplication]).values().asScala.head
Expand All @@ -284,7 +215,7 @@ case class KalixSpringApplication(applicationContext: ApplicationContext, config
provider.setEnvironment(applicationContext.getEnvironment) //use the same environment to get access to properties

// load all Kalix components found in the classpath
val classBeanMap =
private val classBeanMap =
provider.findKalixComponents.map { bean =>
// here we need to load the components using the same loader as the Main class
// this is needed to have it loaded in the RestartClassLoader when using auto-reload
Expand All @@ -293,7 +224,7 @@ case class KalixSpringApplication(applicationContext: ApplicationContext, config
}.toMap

// each loaded class needs to be validated before registration
val validation =
private val validation =
classBeanMap.keySet
.foldLeft(Valid: Validation) { case (validations, cls) =>
validations ++ Validations.validate(cls)
Expand Down Expand Up @@ -368,88 +299,101 @@ case class KalixSpringApplication(applicationContext: ApplicationContext, config

def port: Int = kalixRunner.configuration.userFunctionPort

/* Each component may have a creation context passed to its constructor.
* This method checks if there is a constructor in `clz` that receives a `context`.
*/
private def hasContextConstructor(clz: Class[_], contextType: Class[_]): Boolean =
clz.getConstructors.exists { ctor =>
ctor.getParameterTypes.contains(contextType)
}

private def actionProvider[A <: Action](clz: Class[A]): ActionProvider[A] =
ReflectiveActionProvider.of(
clz,
messageCodec,
context => {
if (hasContextConstructor(clz, classOf[ActionCreationContext]))
ActionCreationContextFactoryBean.set(context)

val webClientProviderHolder = WebClientProviderHolder(context.materializer().system)

setKalixClient(clz, webClientProviderHolder)
setComponentClient(clz, webClientProviderHolder)
private def webClientProvider(context: Context) = {
val webClientProviderHolder = WebClientProviderHolder(context.materializer().system)
webClientProviderHolder.webClientProvider
}

if (hasContextConstructor(clz, classOf[WebClientProvider])) {
val webClientProvider = webClientProviderHolder.webClientProvider
WebClientProviderFactoryBean.set(webClientProvider)
}
/**
* Create an instance of `clz` using the mappings defined in `partial`. Each component provider should define what are
* the acceptable dependencies in the partial function.
*
* If the partial function doesn't match, it will try to lookup in the Spring applicationContext.
*/
private def wiredInstance[T](clz: Class[T])(partial: PartialFunction[Class[_], Any]): T = {
// only one constructor allowed
require(clz.getDeclaredConstructors.length > 1, s"Class [${clz.getSimpleName}] must have only one constructor")
wiredInstance(clz.getDeclaredConstructors.head.asInstanceOf[Constructor[T]])(partial)
}

kalixBeanFactory.getBean(clz)
})
/**
* Create an instance using the passed `constructor` and the mappings defined in `partial`.
*
* Each component provider should define what are the acceptable dependencies in the partial function.
*
* If the partial function doesn't match, it will try to lookup in the Spring applicationContext.
*/
private def wiredInstance[T](constructor: Constructor[T])(partial: PartialFunction[Class[_], Any]): T = {

// Note that this function is total because it will always return a value (even if null)
// last case is a catch all that lookups in the applicationContext
val totalWireFunction: PartialFunction[Class[_], Any] =
partial.orElse {
// block wiring of clients into anything that is not an Action or Workflow
// NOTE: if they are allowed, 'partial' should already have a matching case for them
case p if p == classOf[KalixClient] =>
throw new BeanCreationException(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Look @johanandren, now throwing for real.

s"[${constructor.getDeclaringClass.getSimpleName}] are not allowed to have a dependency on KalixClient")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I re-phrased the text here because I was using the article 'a' in previous commit. It would be wrong for EventSourcedEntity.


case p if p == classOf[ComponentClient] =>
throw new BeanCreationException(
s"[${constructor.getDeclaringClass.getSimpleName}] are not allowed to have a dependency on ComponentClient")

case p if p == classOf[WebClientProvider] =>
throw new BeanCreationException(
s"[${constructor.getDeclaringClass.getSimpleName}] are not allowed to have a dependency on WebClientProvider")

// if partial func doesn't match, try to lookup in the applicationContext
case anyOther =>
val bean = applicationContext.getBean(anyOther)
if (bean == null)
throw new BeanCreationException(
s"Cannot wire [${anyOther.getSimpleName}]. Bean not found in the Application Context");
Comment on lines +349 to +351
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also remove the check on wiring wrong context.

This more generic check will cover all wrong contexts, but also cases in which the user didn't properly define their own bean.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can go one extra mile and provide specific error messages when a user tries to inject a ActionContext into an EventSourcedEntity, for example.

But maybe that's something we should do using our Validation frameworky.

else bean
}

private def setKalixClient[T](clz: Class[T], webClientProviderHolder: WebClientProviderHolder): Unit = {
if (hasContextConstructor(clz, classOf[KalixClient])) {
kalixClient.setWebClient(webClientProviderHolder.webClientProvider.localWebClient)
// we only have one KalixClient, but we only set it to the ThreadLocalFactoryBean
// when building actions, because it's only allowed to inject it in Actions and Workflow Entities
KalixClientFactoryBean.set(kalixClient)
}
}
// all params must be wired so we use 'map' not 'collect'
val params = constructor.getParameterTypes.map(totalWireFunction)

private def setComponentClient[T](clz: Class[T], webClientProviderHolder: WebClientProviderHolder): Unit = {
if (hasContextConstructor(clz, classOf[ComponentClient])) {
kalixClient.setWebClient(webClientProviderHolder.webClientProvider.localWebClient)
ComponentClientFactoryBean.set(componentClient)
}
constructor.newInstance(params: _*)
}

private def eventSourcedEntityProvider[S, E, ES <: EventSourcedEntity[S, E]](
clz: Class[ES]): EventSourcedEntityProvider[S, E, ES] =
ReflectiveEventSourcedEntityProvider.of(
private def actionProvider[A <: Action](clz: Class[A]): ActionProvider[A] =
ReflectiveActionProvider.of(
clz,
messageCodec,
context => {
if (hasContextConstructor(clz, classOf[EventSourcedEntityContext]))
EventSourcedEntityContextFactoryBean.set(context)
kalixBeanFactory.getBean(clz)
})

private def workflowProvider[S, E <: Workflow[S]](clz: Class[E]): WorkflowProvider[S, E] = {
context =>
wiredInstance(clz) {
case p if p == classOf[ActionCreationContext] => context
case p if p == classOf[KalixClient] => kalixClient
case p if p == classOf[ComponentClient] => componentClient
case p if p == classOf[WebClientProvider] => webClientProvider(context)
})

private def workflowProvider[S, W <: Workflow[S]](clz: Class[W]): WorkflowProvider[S, W] = {
ReflectiveWorkflowProvider.of(
clz,
messageCodec,
context => {
if (hasContextConstructor(clz, classOf[WorkflowContext])) {
WorkflowContextFactoryBean.set(context)
}

val webClientProviderHolder = WebClientProviderHolder(context.materializer().system)

setKalixClient(clz, webClientProviderHolder)
setComponentClient(clz, webClientProviderHolder)

val workflowEntity = kalixBeanFactory.getBean(clz)
val workflow =
wiredInstance(clz) {
case p if p == classOf[WorkflowContext] => context
case p if p == classOf[KalixClient] => kalixClient
case p if p == classOf[ComponentClient] => componentClient
case p if p == classOf[WebClientProvider] => webClientProvider(context)
}

val workflowStateType: Class[S] =
workflowEntity.getClass.getGenericSuperclass
workflow.getClass.getGenericSuperclass
.asInstanceOf[ParameterizedType]
.getActualTypeArguments
.head
.asInstanceOf[Class[S]]

messageCodec.registerTypeHints(workflowStateType)

workflowEntity
workflow
.definition()
.getSteps
.asScala
Expand All @@ -461,37 +405,35 @@ case class KalixSpringApplication(applicationContext: ApplicationContext, config
}
.foreach(messageCodec.registerTypeHints)

workflowEntity
workflow
})
}

private def valueEntityProvider[S, E <: ValueEntity[S]](clz: Class[E]): ValueEntityProvider[S, E] =
private def eventSourcedEntityProvider[S, E, ES <: EventSourcedEntity[S, E]](
clz: Class[ES]): EventSourcedEntityProvider[S, E, ES] =
ReflectiveEventSourcedEntityProvider.of(
clz,
messageCodec,
context => wiredInstance(clz) { case p if p == classOf[EventSourcedEntityContext] => context })

private def valueEntityProvider[S, VE <: ValueEntity[S]](clz: Class[VE]): ValueEntityProvider[S, VE] =
ReflectiveValueEntityProvider.of(
clz,
messageCodec,
context => {
if (hasContextConstructor(clz, classOf[ValueEntityContext]))
ValueEntityContextFactoryBean.set(context)
kalixBeanFactory.getBean(clz)
})
context => wiredInstance(clz) { case p if p == classOf[ValueEntityContext] => context })

private def viewProvider[S, V <: View[S]](clz: Class[V]): ViewProvider =
ReflectiveViewProvider.of[S, V](
clz,
messageCodec,
context => {
if (hasContextConstructor(clz, classOf[ViewCreationContext]))
ViewCreationContextFactoryBean.set(context)
kalixBeanFactory.getBean(clz)
})
context => wiredInstance(clz) { case p if p == classOf[ViewCreationContext] => context })

private def multiTableViewProvider[V](clz: Class[V]): ViewProvider =
ReflectiveMultiTableViewProvider.of[V](
clz,
messageCodec,
(viewTableClass, context) => {
if (hasContextConstructor(viewTableClass, classOf[ViewCreationContext]))
ViewCreationContextFactoryBean.set(context)
kalixBeanFactory.getBean(viewTableClass)
val constructor = viewTableClass.getConstructors.head.asInstanceOf[Constructor[View[_]]]
wiredInstance(constructor) { case p if p == classOf[ViewCreationContext] => context }
})
}
Loading
Loading