diff --git a/rhine-bayes/app/Main.hs b/rhine-bayes/app/Main.hs
index 001f35c4..acbf8291 100644
--- a/rhine-bayes/app/Main.hs
+++ b/rhine-bayes/app/Main.hs
@@ -113,9 +113,9 @@ initialTemperature :: Temperature
 initialTemperature = 7
 
 -- | We infer the temperature by randomly moving around with a Brownian motion (Wiener process).
-temperatureProcess :: (MonadDistribution m, Diff td ~ Double) => BehaviourF m td () Temperature
-temperatureProcess = proc () -> do
-  temperatureFactor <- wienerLogDomain 20 -< ()
+temperatureProcess :: (MonadDistribution m, Diff td ~ Double) => BehaviourF m td (Diff td) Temperature
+temperatureProcess = proc dt -> do
+  temperatureFactor <- wienerVaryingLogDomain -< dt
   returnA -< runLogDomain temperatureFactor * initialTemperature
 
 -- | Auxiliary conversion function belonging to the log-domain library, see https://github.com/ekmett/log-domain/issues/38
@@ -133,16 +133,33 @@ genModelWithoutTemperature = proc temperature -> do
   sensor <- generativeModel -< latent
   returnA -< (sensor, latent)
 
+type ESS = Double
+
 {- | Given sensor data, sample a latent position and a temperature, and weight them according to the likelihood of the observed sensor position.
    Used to infer position and temperature.
+   The effective sample size of the particle population, measured at the previous step, is fed back in as the ESS input (see 'withESS'), starting from 20.
 -}
-posteriorTemperatureProcess :: (MonadMeasure m, Diff td ~ Double) => BehaviourF m td Sensor (Temperature, Pos)
-posteriorTemperatureProcess = proc sensor -> do
-  temperature <- temperatureProcess -< ()
+posteriorTemperatureProcessAutoESS :: (MonadDistribution m, TimeDomain td, Diff td ~ Double) => BehaviourF (Population m) td Sensor (Temperature, Pos)
+posteriorTemperatureProcessAutoESS = withESS 20 posteriorTemperatureProcessLiveESS
+
+{- | Given sensor data, sample a latent position and a temperature, and weight them according to the likelihood of the observed sensor position.
+   Used to infer position and temperature.
+   Additionally receives the current effective sample size, which, together with the time since start, controls the variance of the temperature random walk.
+-}
+posteriorTemperatureProcessLiveESS :: (MonadMeasure m, TimeDomain td, Diff td ~ Double) => BehaviourF m td (Sensor, ESS) (Temperature, Pos)
+posteriorTemperatureProcessLiveESS = proc (sensor, ess) -> do
+  t <- sinceStart -< ()
+  temperature <- temperatureProcess -< ess / 2 + t / 20
   latent <- prior -< temperature
   arrM score -< sensorLikelihood latent sensor
   returnA -< (temperature, latent)
 
+{- | Given sensor data, sample a latent position and a temperature, and weight them according to the likelihood of the observed sensor position.
+   Used to infer position and temperature.
+   Like 'posteriorTemperatureProcessLiveESS', but with the effective sample size input fixed at 20.
+-}
+posteriorTemperatureProcess :: (MonadMeasure m, TimeDomain td, Diff td ~ Double) => BehaviourF m td Sensor (Temperature, Pos)
+posteriorTemperatureProcess = proc sensor -> do
+  posteriorTemperatureProcessLiveESS -< (sensor, 20)
+
+
 -- | A collection of all displayable inference results
 data Result = Result
   { temperature :: Temperature
@@ -234,6 +251,8 @@ mains :: [(String, IO ())]
 mains =
   [ ("single rate", mainSingleRate)
   , ("multi rate, temperature process", mainMultiRate)
+  , ("multi rate, temperature process, RMSMC", mainMultiRateRMSMC)
+  , ("multi rate, temperature process, RMSMC dynamic", mainMultiRateRMSMCDyn)
   ]
 
 main :: IO ()
@@ -247,7 +266,7 @@ main = do
 {- | Given an actual temperature, simulate a latent position and measured sensor position,
    and based on the sensor data infer the latent position and the temperature.
 -}
-filtered :: Diff td ~ Double => BehaviourF App td Temperature Result
+filtered :: (TimeDomain td, Diff td ~ Double) => BehaviourF App td Temperature Result
 filtered = proc temperature -> do
   (measured, latent) <- genModelWithoutTemperature -< temperature
   particles <- runPopulationCl nParticles resampleSystematic posteriorTemperatureProcess -< measured
@@ -261,7 +280,7 @@ filtered = proc temperature -> do
       }
 
 -- | Run simulation, inference, and visualization synchronously
-mainClSF :: Diff td ~ Double => BehaviourF App td () ()
+mainClSF :: (TimeDomain td, Diff td ~ Double) => BehaviourF App td () ()
 mainClSF = proc () -> do
   output <- filtered -< initialTemperature
   visualisation -< output
@@ -316,9 +335,28 @@ userTemperature = tagS >>> arr (selector >>> fmap Product) >>> mappendS >>> arr
 inference :: Rhine (GlossConcT IO) (LiftClock IO GlossConcT Busy) (Temperature, (Sensor, Pos)) Result
 inference = hoistClSF sampleIOGloss inferenceBehaviour @@ liftClock Busy
   where
-    inferenceBehaviour :: (MonadDistribution m, Diff td ~ Double, MonadIO m) => BehaviourF m td (Temperature, (Sensor, Pos)) Result
+    inferenceBehaviour :: (MonadDistribution m, TimeDomain td, Diff td ~ Double, MonadIO m) => BehaviourF m td (Temperature, (Sensor, Pos)) Result
+    inferenceBehaviour = proc (temperature, (measured, latent)) -> do
+      particles <- runPopulationCl nParticles (onlyBelowEffectiveSampleSize 60 resampleSystematic) posteriorTemperatureProcessAutoESS -< measured
+      returnA -< Result{temperature, measured, latent, particles}
+
+{- | This part performs the inference (and passes along temperature, sensor and position simulations).
+   Inference is done with resample-move Sequential Monte Carlo (RMSMC).
+   It runs as fast as possible, so this will potentially drain the CPU.
+-}
+inferenceRMSMC :: Rhine (GlossConcT IO) (LiftClock IO GlossConcT Busy) (Temperature, (Sensor, Pos)) Result
+inferenceRMSMC = hoistClSF sampleIOGloss inferenceBehaviour @@ liftClock Busy
+  where
+    inferenceBehaviour :: (MonadDistribution m, TimeDomain td, Diff td ~ Double, MonadIO m) => BehaviourF m td (Temperature, (Sensor, Pos)) Result
+    inferenceBehaviour = proc (temperature, (measured, latent)) -> do
+      particles <- resampleMoveSequentialMonteCarloCl 10 1 resampleSystematic posteriorTemperatureProcess -< measured
+      returnA -< Result{temperature, measured, latent, particles}
+
+-- | Like 'inferenceRMSMC', but using the dynamically traced RMSMC variant, and resampling only when the effective sample size falls below 5.
+inferenceRMSMCDyn :: Rhine (GlossConcT IO) (LiftClock IO GlossConcT Busy) (Temperature, (Sensor, Pos)) Result
+inferenceRMSMCDyn = hoistClSF sampleIOGloss inferenceBehaviour @@ liftClock Busy
+  where
+    inferenceBehaviour :: (MonadDistribution m, TimeDomain td, Diff td ~ Double, MonadIO m) => BehaviourF m td (Temperature, (Sensor, Pos)) Result
     inferenceBehaviour = proc (temperature, (measured, latent)) -> do
-      particles <- runPopulationCl nParticles resampleSystematic posteriorTemperatureProcess -< measured
+      particles <- resampleMoveSequentialMonteCarloDynCl 10 1 (onlyBelowEffectiveSampleSize 5 resampleSystematic) posteriorTemperatureProcess -< measured
       returnA -< Result{temperature, measured, latent, particles}
 
 -- | Visualize the current 'Result' at a rate controlled by the @gloss@ backend, usually 30 FPS.
@@ -336,12 +374,45 @@ mainRhineMultiRate =
       >-- keepLast Result{temperature = initialTemperature, measured = zeroVector, latent = zeroVector, particles = []} -@- glossConcurrently -->
     visualisationRhine
 
+mainRhineMultiRateRMSMC =
+  userTemperature
+    @@ glossClockUTC GlossEventClockIO
+      >-- keepLast initialTemperature -@- glossConcurrently -->
+    modelRhine
+      >-- keepLast (initialTemperature, (zeroVector, zeroVector)) -@- glossConcurrently -->
+    inferenceRMSMC
+      >-- keepLast Result{temperature = initialTemperature, measured = zeroVector, latent = zeroVector, particles = []} -@- glossConcurrently -->
+    visualisationRhine
+
+
+mainRhineMultiRateRMSMCDyn =
+  userTemperature
+    @@ glossClockUTC GlossEventClockIO
+      >-- keepLast initialTemperature -@- glossConcurrently -->
+    modelRhine
+      >-- keepLast (initialTemperature, (zeroVector, zeroVector)) -@- glossConcurrently -->
+    inferenceRMSMCDyn
+      >-- keepLast Result{temperature = initialTemperature, measured = zeroVector, latent = zeroVector, particles = []} -@- glossConcurrently -->
+    visualisationRhine
+
 mainMultiRate :: IO ()
 mainMultiRate =
   void $
     launchGlossThread glossSettings $
       flow mainRhineMultiRate
 
+mainMultiRateRMSMC :: IO ()
+mainMultiRateRMSMC =
+  void $
+    launchGlossThread glossSettings $
+      flow mainRhineMultiRateRMSMC
+
+mainMultiRateRMSMCDyn :: IO ()
+mainMultiRateRMSMCDyn =
+  void $
+    launchGlossThread glossSettings $
+      flow mainRhineMultiRateRMSMCDyn
+
 -- * Utilities
 
 instance MonadDistribution m => MonadDistribution (GlossConcT m) where
diff --git a/rhine-bayes/src/Data/MonadicStreamFunction/Bayes.hs b/rhine-bayes/src/Data/MonadicStreamFunction/Bayes.hs
index a33149c8..5997e528 100644
--- a/rhine-bayes/src/Data/MonadicStreamFunction/Bayes.hs
+++ b/rhine-bayes/src/Data/MonadicStreamFunction/Bayes.hs
@@ -4,6 +4,7 @@ module Data.MonadicStreamFunction.Bayes where
 import Control.Arrow
 import Data.Functor (($>))
 import Data.Tuple (swap)
+import Debug.Trace
 
 -- transformers
 
@@ -12,11 +13,19 @@ import Numeric.Log hiding (sum)
 
 -- monad-bayes
 import Control.Monad.Bayes.Population
+import Control.Monad.Bayes.Traced
 
 -- dunai
 import Data.MonadicStreamFunction
 import Data.MonadicStreamFunction.InternalCore (MSF (..))
+import Control.Monad.Trans.Class
+import Control.Monad.Bayes.Class (MonadDistribution)
+import qualified Control.Monad.Bayes.Traced.Static as Static
+import Control.Monad.Bayes.Sequential.Coroutine (hoistFirst)
+import Control.Monad.Trans.MSF (performOnFirstSample)
+import qualified Control.Monad.Bayes.Traced.Dynamic as Dynamic
 
+-- FIXME rename to sequentialMonteCarlo or smc?
 -- | Run the Sequential Monte Carlo algorithm continuously on an 'MSF'
 runPopulationS ::
   forall m a b.
@@ -48,6 +57,115 @@ runPopulationsS resampler = go
           unzip $ (swap . fmap fst &&& swap . fmap snd) . swap <$> bAndMSFs
 
+resampleMoveSequentialMonteCarlo ::
+  forall m a b.
+  MonadDistribution m =>
+  -- (MonadDistribution m, HasTraced t, MonadTrans t) =>
+  -- | Number of particles
+  Int ->
+  -- | Number of MC steps
+  Int ->
+  -- | Resampler
+  (forall x. Population m x -> Population m x) ->
+  MSF (Static.Traced (Population m)) a b ->
+  -- MSF (t (Population m)) a b ->
+  -- FIXME Why not MSF m a (Population b)
+  MSF m a [(b, Log Double)]
+resampleMoveSequentialMonteCarlo nParticles nMC resampler = go . Control.Monad.Bayes.Traced.hoist (spawn nParticles >>) . pure
+  where
+    go ::
+      Monad m =>
+      Static.Traced (Population m) (MSF (Static.Traced (Population m)) a b) ->
+      -- Population m (MSF (t (Population m)) a b) ->
+      MSF m a [(b, Log Double)]
+    go msfs = MSF $ \a -> do
+      -- TODO This is quite different than the dunai version now. Maybe it's right nevertheless.
+      -- FIXME This normalizes, which introduces bias, whatever that means
+      let bAndMSFs = composeCopies nMC mhStep
+              $ Control.Monad.Bayes.Traced.hoist resampler
+              $ flip unMSF a =<< msfs
+      bs <- runPopulation $ marginal $ fst <$> bAndMSFs
+      return (bs, go $ snd <$> bAndMSFs)
+
+resampleMoveSequentialMonteCarloDynamic ::
+  forall m a b.
+  MonadDistribution m =>
+  -- (MonadDistribution m, HasTraced t, MonadTrans t) =>
+  -- | Number of particles
+  Int ->
+  -- | Number of MC steps
+  Int ->
+  -- | Resampler
+  (forall x. Population m x -> Population m x) ->
+  MSF (Dynamic.Traced (Population m)) a b ->
+  -- MSF (t (Population m)) a b ->
+  -- FIXME Why not MSF m a (Population b)
+  MSF m a [(b, Log Double)]
+resampleMoveSequentialMonteCarloDynamic nParticles nMC resampler = go . Dynamic.hoist (spawn nParticles >>) . pure
+  where
+    go ::
+      Monad m =>
+      Dynamic.Traced (Population m) (MSF (Dynamic.Traced (Population m)) a b) ->
+      -- Population m (MSF (t (Population m)) a b) ->
+      MSF m a [(b, Log Double)]
+    go msfs = MSF $ \a -> do
+      -- TODO This is quite different than the dunai version now. Maybe it's right nevertheless.
+      -- FIXME This normalizes, which introduces bias, whatever that means
+      let bAndMSFs = Dynamic.freeze
+              $ composeCopies nMC Dynamic.mhStep
+              $ Dynamic.hoist resampler
+              $ flip unMSF a =<< msfs
+      bs <- runPopulation $ Dynamic.marginal $ fst <$> bAndMSFs
+      return (bs, go $ snd <$> bAndMSFs)
+
+-- | Apply a function a given number of times.
+composeCopies :: Int -> (a -> a) -> (a -> a)
+composeCopies k f = foldr (.) id (replicate k f)
+
+tracePop :: Monad m => String -> Population m a -> Population m a
+-- tracePop msg = fromWeightedList . fmap (\pop -> Debug.Trace.traceShow (msg, length pop) pop) . runPopulation
+tracePop _ = id
+
+-- resampleMoveSequentialMonteCarlo nParticles nMC resampler = morphS marginal $ runPopulationS nParticles $ freeze . composeCopies nMC mhStep . hoistTrace resampler
+
 -- FIXME see PR re-adding this to monad-bayes
 normalize :: Monad m => Population m a -> Population m a
 normalize = fromWeightedList . fmap (\particles -> second (/ (sum $ snd <$> particles)) <$> particles) . runPopulation
+
+-- FIXME See PR to monad-bayes
+
+
+-- | Only use the given resampler when the effective sample size is below a certain threshold
+onlyBelowEffectiveSampleSize ::
+  MonadDistribution m =>
+  -- | The threshold under which the effective sample size must fall before the resampler is used.
+  -- For example, this may be half of the number of particles.
+  Double ->
+  -- | The resampler to use under the threshold
+  (forall n . MonadDistribution n => Population n a -> Population n a) ->
+  -- | The new resampler
+  (Population m a -> Population m a)
+onlyBelowEffectiveSampleSize threshold resampler pop = do
+  (particles, ess) <- lift $ runWithEffectiveSampleSize pop
+  let newPop = fromWeightedList $ pure particles
+  -- This assumes that the resampler does not mutate the m effects, as it should
+  if ess < threshold then resampler newPop else newPop
+
+-- | Compute the effective sample size of a population from the weights.
+--
+-- See https://en.wikipedia.org/wiki/Design_effect#Effective_sample_size
+runWithEffectiveSampleSize :: Functor m => Population m a -> m ([(a, Log Double)], Double)
+runWithEffectiveSampleSize = fmap (id &&& (effectiveSampleSizeKish . map (exp . ln . snd))) . runPopulation
+  where
+    effectiveSampleSizeKish :: [Double] -> Double
+    effectiveSampleSizeKish weights = square (sum weights) / sum (square <$> weights)
+    square :: Double -> Double
+    square x = x * x
+
+-- | Annotate the output at every step with the Kish effective sample size of the current particle population.
+measureESS :: Monad m => MSF (Population m) a b -> MSF (Population m) a (b, Double)
+measureESS = morphGS $ fmap $ \pop -> fromWeightedList $ do
+  (particles, ess) <- runWithEffectiveSampleSize pop
+  pure $ map (first (first (, ess))) particles
+
+-- | Feed the effective sample size measured at the previous step back in as an extra input, starting with the given initial value.
+withESS :: Monad m => Double -> MSF (Population m) (a, Double) b -> MSF (Population m) a b
+withESS initESS = feedback initESS . measureESS
diff --git a/rhine-bayes/src/FRP/Rhine/Bayes.hs b/rhine-bayes/src/FRP/Rhine/Bayes.hs
index 5422a0c3..fd4374f7 100644
--- a/rhine-bayes/src/FRP/Rhine/Bayes.hs
+++ b/rhine-bayes/src/FRP/Rhine/Bayes.hs
@@ -1,4 +1,5 @@
-module FRP.Rhine.Bayes where
+module FRP.Rhine.Bayes
+  (module FRP.Rhine.Bayes, module X) where
 
 -- log-domain
 import Numeric.Log hiding (sum)
@@ -13,8 +14,12 @@ import qualified Control.Monad.Trans.MSF.Reader as DunaiReader
 
 -- dunai-bayes
 import qualified Data.MonadicStreamFunction.Bayes as DunaiBayes
+import Data.MonadicStreamFunction.Bayes as X (onlyBelowEffectiveSampleSize)
+
 -- rhine
 import FRP.Rhine
+import qualified Control.Monad.Bayes.Traced.Static as Static
+import qualified Control.Monad.Bayes.Traced.Dynamic as Dynamic
 
 -- * Inference methods
 
@@ -32,6 +37,39 @@ runPopulationCl :: forall m cl a b . Monad m =>
   -> ClSF m cl a [(b, Log Double)]
 runPopulationCl nParticles resampler = DunaiReader.readerS . DunaiBayes.runPopulationS nParticles resampler . DunaiReader.runReaderS
 
+-- | Run the Resample Move Sequential Monte Carlo algorithm continuously on a 'ClSF'.
+resampleMoveSequentialMonteCarloCl :: forall m cl a b . (MonadDistribution m) =>
+-- resampleMoveSequentialMonteCarloCl :: forall t m cl a b . (MonadDistribution m, HasTraced t, MonadTrans t) =>
+  -- | Number of particles
+  Int ->
+  -- | Number of MC steps
+  Int ->
+  -- | Resampler (see 'Control.Monad.Bayes.Population' for some standard choices)
+  (forall x . Population m x -> Population m x)
+  -- | A signal function modelling the stochastic process on which to perform inference.
+  -- @a@ represents observations upon which the model should condition, using e.g. 'score'.
+  -- It can also additionally contain hyperparameters.
+  -- @b@ is the type of estimated current state.
+  -> ClSF (Static.Traced (Population m)) cl a b
+  -> ClSF m cl a [(b, Log Double)]
+resampleMoveSequentialMonteCarloCl nParticles nMC resampler = DunaiReader.readerS . DunaiBayes.resampleMoveSequentialMonteCarlo nParticles nMC resampler . DunaiReader.runReaderS
+
+-- | Like 'resampleMoveSequentialMonteCarloCl', but based on the dynamically traced representation from "Control.Monad.Bayes.Traced.Dynamic".
+resampleMoveSequentialMonteCarloDynCl :: forall m cl a b . (MonadDistribution m) =>
+-- resampleMoveSequentialMonteCarloCl :: forall t m cl a b . (MonadDistribution m, HasTraced t, MonadTrans t) =>
+  -- | Number of particles
+  Int ->
+  -- | Number of MC steps
+  Int ->
+  -- | Resampler (see 'Control.Monad.Bayes.Population' for some standard choices)
+  (forall x . Population m x -> Population m x)
+  -- | A signal function modelling the stochastic process on which to perform inference.
+  -- @a@ represents observations upon which the model should condition, using e.g. 'score'.
+  -- It can also additionally contain hyperparameters.
+  -- @b@ is the type of estimated current state.
+  -> ClSF (Dynamic.Traced (Population m)) cl a b
+  -> ClSF m cl a [(b, Log Double)]
+resampleMoveSequentialMonteCarloDynCl nParticles nMC resampler = DunaiReader.readerS . DunaiBayes.resampleMoveSequentialMonteCarloDynamic nParticles nMC resampler . DunaiReader.runReaderS
+
 -- * Short standard library of stochastic processes
 
 -- | White noise, that is, an independent normal distribution at every time step.
@@ -83,3 +121,8 @@ wienerVaryingLogDomain ::
   (MonadDistribution m, Diff td ~ Double) =>
   BehaviourF m td (Diff td) (Log Double)
 wienerVaryingLogDomain = wienerVarying >>> arr Exp
+
+-- | Feed the effective sample size measured at the previous step back in as an extra input to the 'ClSF', starting with the given initial value.
+withESS :: Monad m => Double -> ClSF (Population m) cl (a, Double) b -> ClSF (Population m) cl a b
+withESS initESS = DunaiReader.readerS . DunaiBayes.withESS initESS . (arr assoc >>>) . DunaiReader.runReaderS
+  where
+    assoc ((ti, a), ess) = (ti, (a, ess))
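
Usage sketch (not part of the diff above): the new pieces compose the same way `inferenceBehaviour` does in Main.hs. `onlyBelowEffectiveSampleSize` wraps an ordinary resampler so that systematic resampling only fires when the Kish effective sample size drops below the threshold, and `posteriorTemperatureProcessAutoESS` receives the measured ESS as feedback. The name `adaptiveTemperatureFilter` is made up for illustration; everything else comes from the modules touched by this patch, in the context of rhine-bayes/app/Main.hs.

    -- Bootstrap particle filter over the temperature model: resample
    -- systematically, but only when the Kish effective sample size of the
    -- population falls below 60. The measured ESS is also fed back into the
    -- proposal via posteriorTemperatureProcessAutoESS.
    adaptiveTemperatureFilter ::
      (MonadDistribution m, TimeDomain td, Diff td ~ Double) =>
      BehaviourF m td Sensor [((Temperature, Pos), Log Double)]
    adaptiveTemperatureFilter =
      runPopulationCl
        nParticles
        (onlyBelowEffectiveSampleSize 60 resampleSystematic)
        posteriorTemperatureProcessAutoESS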