diff --git a/entities/src/main/scala/com/devsisters/shardcake/internal/EntityManager.scala b/entities/src/main/scala/com/devsisters/shardcake/internal/EntityManager.scala index 2292f6f..8efca13 100644 --- a/entities/src/main/scala/com/devsisters/shardcake/internal/EntityManager.scala +++ b/entities/src/main/scala/com/devsisters/shardcake/internal/EntityManager.scala @@ -29,26 +29,29 @@ private[shardcake] object EntityManager { entityMaxIdleTime: Option[Duration] ): URIO[R, EntityManager[Req]] = for { - entities <- Ref.Synchronized.make[Map[String, (Either[Queue[Req], Signal], EpochMillis)]](Map()) - env <- ZIO.environment[R] + entities <- Ref.Synchronized.make[Map[String, Either[Queue[Req], Signal]]](Map()) + entitiesLastReceivedAt <- Ref.make[Map[String, EpochMillis]](Map()) + env <- ZIO.environment[R] } yield new EntityManagerLive[Req]( recipientType, (entityId: String, queue: Queue[Req]) => behavior(entityId, queue).provideEnvironment(env), terminateMessage, entities, + entitiesLastReceivedAt, sharding, config, entityMaxIdleTime ) - private def currentTimeInMilliseconds: UIO[EpochMillis] = + private val currentTimeInMilliseconds: UIO[EpochMillis] = Clock.currentTime(TimeUnit.MILLISECONDS) class EntityManagerLive[Req]( recipientType: RecipientType[Req], behavior: (String, Queue[Req]) => Task[Nothing], terminateMessage: Signal => Option[Req], - entities: Ref.Synchronized[Map[String, (Either[Queue[Req], Signal], EpochMillis)]], + entities: Ref.Synchronized[Map[String, Either[Queue[Req], Signal]]], + entitiesLastReceivedAt: Ref[Map[String, EpochMillis]], sharding: Sharding, config: Config, entityMaxIdleTime: Option[Duration] @@ -56,14 +59,15 @@ private[shardcake] object EntityManager { private def startExpirationFiber(entityId: String): UIO[Fiber[Nothing, Unit]] = { val maxIdleTime = entityMaxIdleTime getOrElse config.entityMaxIdleTime - def sleep(duration: Duration): UIO[Unit] = for { - _ <- Clock.sleep(duration) - cdt <- currentTimeInMilliseconds - map <- entities.get - lastReceivedAt = map.get(entityId).map { case (_, lastReceivedAt) => lastReceivedAt }.getOrElse(0L) - remaining = maxIdleTime minus Duration.fromMillis(cdt - lastReceivedAt) - _ <- sleep(remaining).when(remaining > Duration.Zero) - } yield () + def sleep(duration: Duration): UIO[Unit] = + for { + _ <- Clock.sleep(duration) + cdt <- currentTimeInMilliseconds + map <- entitiesLastReceivedAt.get + lastReceivedAt = map.getOrElse(entityId, 0L) + remaining = maxIdleTime minus Duration.fromMillis(cdt - lastReceivedAt) + _ <- sleep(remaining).when(remaining > Duration.Zero) + } yield () (for { _ <- sleep(maxIdleTime) @@ -74,19 +78,19 @@ private[shardcake] object EntityManager { private def terminateEntity(entityId: String): UIO[Unit] = entities.updateZIO(map => map.get(entityId) match { - case Some((Left(queue), lastReceivedAt)) => + case Some(Left(queue)) => Promise .make[Nothing, Unit] .flatMap { p => terminateMessage(p) match { case Some(msg) => // if a queue is found, offer the termination message, and set the queue to None so that no new message is enqueued - queue.offer(msg).exit.as(map.updated(entityId, (Right(p), lastReceivedAt))) + queue.offer(msg).exit.as(map.updated(entityId, Right(p))) case None => queue.shutdown.as(map - entityId) } } - case _ => + case _ => // if no queue is found, do nothing ZIO.succeed(map) } @@ -108,50 +112,60 @@ private[shardcake] object EntityManager { case _: TopicType[_] => ZIO.unit } // find the queue for that entity, or create it if needed - queue <- entities.modifyZIO(map => - map.get(entityId) match { - case Some((queue @ Left(_), _)) => - // queue exists, delay the interruption fiber and return the queue - currentTimeInMilliseconds.map(cdt => (queue, map.updated(entityId, (queue, cdt)))) - case Some((p @ Right(_), _)) => - // the queue is shutting down, stash and retry - ZIO.succeed((p, map)) - case None => - sharding.isShuttingDown.flatMap { - case true => - // don't start any fiber while sharding is shutting down - ZIO.fail(EntityNotManagedByThisPod(entityId)) - case false => - // queue doesn't exist, create a new one - for { - queue <- Queue.unbounded[Req] - // start the expiration fiber - expirationFiber <- startExpirationFiber(entityId) - _ <- behavior(entityId, queue) - .ensuring( - // shutdown the queue when the fiber ends - entities.update(_ - entityId) *> queue.shutdown *> expirationFiber.interrupt - ) - .forkDaemon - cdt <- currentTimeInMilliseconds - leftQueue = Left(queue) - } yield (leftQueue, map.updated(entityId, (leftQueue, cdt))) - } - } - ) + map <- entities.get + queue <- map.get(entityId) match { + case Some(queue @ Left(_)) => ZIO.succeed(queue) + case _ => getOrCreateQueue(entityId) + } _ <- queue match { case Right(_) => // the queue is shutting down, try again a little later Clock.sleep(100 millis) *> send(entityId, req, replyId, replyChannel) case Left(queue) => - // add the message to the queue and setup the reply channel if needed - (replyId match { - case Some(replyId) => sharding.initReply(replyId, replyChannel) *> queue.offer(req) - case None => queue.offer(req) *> replyChannel.end - }).catchAllCause(_ => send(entityId, req, replyId, replyChannel)) + currentTimeInMilliseconds.flatMap(cdt => entitiesLastReceivedAt.update(_ + (entityId -> cdt))) *> + // add the message to the queue and setup the reply channel if needed + (replyId match { + case Some(replyId) => sharding.initReply(replyId, replyChannel) *> queue.offer(req) + case None => queue.offer(req) *> replyChannel.end + }).catchAllCause(_ => send(entityId, req, replyId, replyChannel)) } } yield () + private def getOrCreateQueue(entityId: String): IO[EntityNotManagedByThisPod, Either[Queue[Req], Signal]] = + entities.modifyZIO(map => + map.get(entityId) match { + case Some(queue @ Left(_)) => + // the queue already exists, return it + ZIO.succeed((queue, map)) + case Some(p @ Right(_)) => + // the queue is shutting down, stash and retry + ZIO.succeed((p, map)) + case None => + sharding.isShuttingDown.flatMap { + case true => + // don't start any fiber while sharding is shutting down + ZIO.fail(EntityNotManagedByThisPod(entityId)) + case false => + // queue doesn't exist, create a new one + for { + queue <- Queue.unbounded[Req] + // start the expiration fiber + expirationFiber <- startExpirationFiber(entityId) + _ <- behavior(entityId, queue) + .ensuring( + // shutdown the queue when the fiber ends + entities.update(_ - entityId) *> + entitiesLastReceivedAt.update(_ - entityId) *> + queue.shutdown *> + expirationFiber.interrupt + ) + .forkDaemon + leftQueue = Left(queue) + } yield (leftQueue, map.updated(entityId, leftQueue)) + } + } + ) + def terminateEntitiesOnShards(shards: Set[ShardId]): UIO[Unit] = entities.modify { entities => // get all entities on the given shards to terminate them @@ -162,12 +176,10 @@ private[shardcake] object EntityManager { def terminateAllEntities: UIO[Unit] = entities.getAndSet(Map()).flatMap(terminateEntities) - private def terminateEntities( - entitiesToTerminate: Map[String, (Either[Queue[Req], Signal], EpochMillis)] - ): UIO[Unit] = + private def terminateEntities(entitiesToTerminate: Map[String, Either[Queue[Req], Signal]]): UIO[Unit] = for { // send termination message to all entities - promises <- ZIO.foreach(entitiesToTerminate.toList) { case (_, (queue, _)) => + promises <- ZIO.foreach(entitiesToTerminate.toList) { case (_, queue) => Promise .make[Nothing, Unit] .flatMap(p =>