Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Effective Sample Size & tempering/annealing #222

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 81 additions & 10 deletions rhine-bayes/app/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@ initialTemperature :: Temperature
initialTemperature = 7

-- | We infer the temperature by randomly moving around with a Brownian motion (Wiener process).
temperatureProcess :: (MonadDistribution m, Diff td ~ Double) => BehaviourF m td () Temperature
temperatureProcess = proc () -> do
temperatureFactor <- wienerLogDomain 20 -< ()
temperatureProcess :: (MonadDistribution m, Diff td ~ Double) => BehaviourF m td (Diff td) Temperature
temperatureProcess = proc dt -> do
temperatureFactor <- wienerVaryingLogDomain -< dt
returnA -< runLogDomain temperatureFactor * initialTemperature

-- | Auxiliary conversion function belonging to the log-domain library, see https://github.com/ekmett/log-domain/issues/38
Expand All @@ -133,16 +133,33 @@ genModelWithoutTemperature = proc temperature -> do
sensor <- generativeModel -< latent
returnA -< (sensor, latent)

type ESS = Double

{- | Given sensor data, sample a latent position and a temperature, and weight them according to the likelihood of the observed sensor position.
Used to infer position and temperature.
-}
posteriorTemperatureProcess :: (MonadMeasure m, Diff td ~ Double) => BehaviourF m td Sensor (Temperature, Pos)
posteriorTemperatureProcess = proc sensor -> do
temperature <- temperatureProcess -< ()
posteriorTemperatureProcessAutoESS :: (MonadDistribution m, TimeDomain td, Diff td ~ Double) => BehaviourF (Population m) td Sensor (Temperature, Pos)
posteriorTemperatureProcessAutoESS = withESS 20 posteriorTemperatureProcessLiveESS

{- | Given sensor data, sample a latent position and a temperature, and weight them according to the likelihood of the observed sensor position.
Used to infer position and temperature.
-}
posteriorTemperatureProcessLiveESS :: (MonadMeasure m, TimeDomain td, Diff td ~ Double) => BehaviourF m td (Sensor, ESS) (Temperature, Pos)
posteriorTemperatureProcessLiveESS = proc (sensor, ess) -> do
t <- sinceStart -< ()
temperature <- temperatureProcess -< ess / 2 + t / 20
latent <- prior -< temperature
arrM score -< sensorLikelihood latent sensor
returnA -< (temperature, latent)

{- | Given sensor data, sample a latent position and a temperature, and weight them according to the likelihood of the observed sensor position.
Used to infer position and temperature.
-}
posteriorTemperatureProcess :: (MonadMeasure m, TimeDomain td, Diff td ~ Double) => BehaviourF m td Sensor (Temperature, Pos)
posteriorTemperatureProcess = proc sensor -> do
posteriorTemperatureProcessLiveESS -< (sensor, 20)


-- | A collection of all displayable inference results
data Result = Result
{ temperature :: Temperature
Expand Down Expand Up @@ -234,6 +251,8 @@ mains :: [(String, IO ())]
mains =
[ ("single rate", mainSingleRate)
, ("multi rate, temperature process", mainMultiRate)
, ("multi rate, temperature process, RMSMC", mainMultiRateRMSMC)
, ("multi rate, temperature process, RMSMC dynamic", mainMultiRateRMSMCDyn)
]

main :: IO ()
Expand All @@ -247,7 +266,7 @@ main = do
{- | Given an actual temperature, simulate a latent position and measured sensor position,
and based on the sensor data infer the latent position and the temperature.
-}
filtered :: Diff td ~ Double => BehaviourF App td Temperature Result
filtered :: (TimeDomain td, Diff td ~ Double) => BehaviourF App td Temperature Result
filtered = proc temperature -> do
(measured, latent) <- genModelWithoutTemperature -< temperature
particles <- runPopulationCl nParticles resampleSystematic posteriorTemperatureProcess -< measured
Expand All @@ -261,7 +280,7 @@ filtered = proc temperature -> do
}

-- | Run simulation, inference, and visualization synchronously
mainClSF :: Diff td ~ Double => BehaviourF App td () ()
mainClSF :: (TimeDomain td, Diff td ~ Double) => BehaviourF App td () ()
mainClSF = proc () -> do
output <- filtered -< initialTemperature
visualisation -< output
Expand Down Expand Up @@ -316,9 +335,28 @@ userTemperature = tagS >>> arr (selector >>> fmap Product) >>> mappendS >>> arr
inference :: Rhine (GlossConcT IO) (LiftClock IO GlossConcT Busy) (Temperature, (Sensor, Pos)) Result
inference = hoistClSF sampleIOGloss inferenceBehaviour @@ liftClock Busy
where
inferenceBehaviour :: (MonadDistribution m, Diff td ~ Double, MonadIO m) => BehaviourF m td (Temperature, (Sensor, Pos)) Result
inferenceBehaviour :: (MonadDistribution m, TimeDomain td, Diff td ~ Double, MonadIO m) => BehaviourF m td (Temperature, (Sensor, Pos)) Result
inferenceBehaviour = proc (temperature, (measured, latent)) -> do
particles <- runPopulationCl nParticles (onlyBelowEffectiveSampleSize 60 resampleSystematic) posteriorTemperatureProcessAutoESS -< measured
returnA -< Result{temperature, measured, latent, particles}

{- | This part performs the inference (and passes along temperature, sensor and position simulations).
It runs as fast as possible, so this will potentially drain the CPU.
-}
inferenceRMSMC :: Rhine (GlossConcT IO) (LiftClock IO GlossConcT Busy) (Temperature, (Sensor, Pos)) Result
inferenceRMSMC = hoistClSF sampleIOGloss inferenceBehaviour @@ liftClock Busy
where
inferenceBehaviour :: (MonadDistribution m, TimeDomain td, Diff td ~ Double, MonadIO m) => BehaviourF m td (Temperature, (Sensor, Pos)) Result
inferenceBehaviour = proc (temperature, (measured, latent)) -> do
particles <- resampleMoveSequentialMonteCarloCl 10 1 resampleSystematic posteriorTemperatureProcess -< measured
returnA -< Result{temperature, measured, latent, particles}

inferenceRMSMCDyn :: Rhine (GlossConcT IO) (LiftClock IO GlossConcT Busy) (Temperature, (Sensor, Pos)) Result
inferenceRMSMCDyn = hoistClSF sampleIOGloss inferenceBehaviour @@ liftClock Busy
where
inferenceBehaviour :: (MonadDistribution m, TimeDomain td, Diff td ~ Double, MonadIO m) => BehaviourF m td (Temperature, (Sensor, Pos)) Result
inferenceBehaviour = proc (temperature, (measured, latent)) -> do
particles <- runPopulationCl nParticles resampleSystematic posteriorTemperatureProcess -< measured
particles <- resampleMoveSequentialMonteCarloDynCl 10 1 (onlyBelowEffectiveSampleSize 5 resampleSystematic) posteriorTemperatureProcess -< measured
returnA -< Result{temperature, measured, latent, particles}

-- | Visualize the current 'Result' at a rate controlled by the @gloss@ backend, usually 30 FPS.
Expand All @@ -336,12 +374,45 @@ mainRhineMultiRate =
>-- keepLast Result{temperature = initialTemperature, measured = zeroVector, latent = zeroVector, particles = []} -@- glossConcurrently -->
visualisationRhine

mainRhineMultiRateRMSMC =
userTemperature
@@ glossClockUTC GlossEventClockIO
>-- keepLast initialTemperature -@- glossConcurrently -->
modelRhine
>-- keepLast (initialTemperature, (zeroVector, zeroVector)) -@- glossConcurrently -->
inferenceRMSMC
>-- keepLast Result{temperature = initialTemperature, measured = zeroVector, latent = zeroVector, particles = []} -@- glossConcurrently -->
visualisationRhine


mainRhineMultiRateRMSMCDyn =
userTemperature
@@ glossClockUTC GlossEventClockIO
>-- keepLast initialTemperature -@- glossConcurrently -->
modelRhine
>-- keepLast (initialTemperature, (zeroVector, zeroVector)) -@- glossConcurrently -->
inferenceRMSMCDyn
>-- keepLast Result{temperature = initialTemperature, measured = zeroVector, latent = zeroVector, particles = []} -@- glossConcurrently -->
visualisationRhine

mainMultiRate :: IO ()
mainMultiRate =
void $
launchGlossThread glossSettings $
flow mainRhineMultiRate

mainMultiRateRMSMC :: IO ()
mainMultiRateRMSMC =
void $
launchGlossThread glossSettings $
flow mainRhineMultiRateRMSMC

mainMultiRateRMSMCDyn :: IO ()
mainMultiRateRMSMCDyn =
void $
launchGlossThread glossSettings $
flow mainRhineMultiRateRMSMCDyn

-- * Utilities

instance MonadDistribution m => MonadDistribution (GlossConcT m) where
Expand Down
118 changes: 118 additions & 0 deletions rhine-bayes/src/Data/MonadicStreamFunction/Bayes.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ module Data.MonadicStreamFunction.Bayes where
import Control.Arrow
import Data.Functor (($>))
import Data.Tuple (swap)
import Debug.Trace

-- transformers

Expand All @@ -12,11 +13,19 @@ import Numeric.Log hiding (sum)

-- monad-bayes
import Control.Monad.Bayes.Population
import Control.Monad.Bayes.Traced

-- dunai
import Data.MonadicStreamFunction
import Data.MonadicStreamFunction.InternalCore (MSF (..))
import Control.Monad.Trans.Class
import Control.Monad.Bayes.Class (MonadDistribution)
import qualified Control.Monad.Bayes.Traced.Static as Static
import Control.Monad.Bayes.Sequential.Coroutine (hoistFirst)
import Control.Monad.Trans.MSF (performOnFirstSample)
import qualified Control.Monad.Bayes.Traced.Dynamic as Dynamic

-- FIXME rename to sequentialMonteCarlo or smc?
-- | Run the Sequential Monte Carlo algorithm continuously on an 'MSF'
runPopulationS ::
forall m a b.
Expand Down Expand Up @@ -48,6 +57,115 @@ runPopulationsS resampler = go
unzip $
(swap . fmap fst &&& swap . fmap snd) . swap <$> bAndMSFs

resampleMoveSequentialMonteCarlo ::
forall m a b.
MonadDistribution m =>
-- (MonadDistribution m, HasTraced t, MonadTrans t) =>
-- | Number of particles
Int ->
-- | Number of MC steps
Int ->
-- | Resampler
(forall x. Population m x -> Population m x) ->
MSF (Static.Traced (Population m)) a b ->
-- MSF (t (Population m)) a b ->
-- FIXME Why not MSF m a (Population b)
MSF m a [(b, Log Double)]
resampleMoveSequentialMonteCarlo nParticles nMC resampler = go . Control.Monad.Bayes.Traced.hoist (spawn nParticles >>) . pure
where
go ::
Monad m =>
Static.Traced (Population m) (MSF (Static.Traced (Population m)) a b) ->
-- Population m (MSF (t (Population m)) a b) ->
MSF m a [(b, Log Double)]
go msfs = MSF $ \a -> do
-- TODO This is quite different than the dunai version now. Maybe it's right nevertheless.
-- FIXME This normalizes, which introduces bias, whatever that means
let bAndMSFs = composeCopies nMC mhStep
$ Control.Monad.Bayes.Traced.hoist resampler
$ flip unMSF a =<< msfs
bs <- runPopulation $ marginal $ fst <$> bAndMSFs
return (bs, go $ snd <$> bAndMSFs)

resampleMoveSequentialMonteCarloDynamic ::
forall m a b.
MonadDistribution m =>
-- (MonadDistribution m, HasTraced t, MonadTrans t) =>
-- | Number of particles
Int ->
-- | Number of MC steps
Int ->
-- | Resampler
(forall x. Population m x -> Population m x) ->
MSF (Dynamic.Traced (Population m)) a b ->
-- MSF (t (Population m)) a b ->
-- FIXME Why not MSF m a (Population b)
MSF m a [(b, Log Double)]
resampleMoveSequentialMonteCarloDynamic nParticles nMC resampler = go . Dynamic.hoist (spawn nParticles >>) . pure
where
go ::
Monad m =>
Dynamic.Traced (Population m) (MSF (Dynamic.Traced (Population m)) a b) ->
-- Population m (MSF (t (Population m)) a b) ->
MSF m a [(b, Log Double)]
go msfs = MSF $ \a -> do
-- TODO This is quite different than the dunai version now. Maybe it's right nevertheless.
-- FIXME This normalizes, which introduces bias, whatever that means
let bAndMSFs = Dynamic.freeze
$ composeCopies nMC Dynamic.mhStep
$ Dynamic.hoist resampler
$ flip unMSF a =<< msfs
bs <- runPopulation $ Dynamic.marginal $ fst <$> bAndMSFs
return (bs, go $ snd <$> bAndMSFs)

-- | Apply a function a given number of times.
composeCopies :: Int -> (a -> a) -> (a -> a)
composeCopies k f = foldr (.) id (replicate k f)

tracePop :: Monad m => String -> Population m a -> Population m a
-- tracePop msg = fromWeightedList . fmap (\pop -> Debug.Trace.traceShow (msg, length pop) pop) . runPopulation
tracePop _ = id

-- resampleMoveSequentialMonteCarlo nParticles nMC resampler = morphS marginal $ runPopulationS nParticles $ freeze . composeCopies nMC mhStep . hoistTrace resampler

-- FIXME see PR re-adding this to monad-bayes
normalize :: Monad m => Population m a -> Population m a
normalize = fromWeightedList . fmap (\particles -> second (/ (sum $ snd <$> particles)) <$> particles) . runPopulation

-- FIXME See PR to monad-bayes


-- | Only use the given resampler when the effective sample size is below a certain threshold
onlyBelowEffectiveSampleSize ::
MonadDistribution m =>
-- | The threshold under which the effective sample size must fall before the resampler is used.
-- For example, this may be half of the number of particles.
Double ->
-- | The resampler to user under the threshold
(forall n . MonadDistribution n => Population n a -> Population n a) ->
-- | The new resampler
(Population m a -> Population m a)
onlyBelowEffectiveSampleSize threshold resampler pop = do
(particles, ess) <- lift $ runWithEffectiveSampleSize pop
let newPop = fromWeightedList $ pure particles
-- This assumes that the resampler does not mutate the m effects, as it should
if ess < threshold then resampler newPop else newPop

-- | Compute the effective sample size of a population from the weights.
--
-- See https://en.wikipedia.org/wiki/Design_effect#Effective_sample_size
runWithEffectiveSampleSize :: Functor m => Population m a -> m ([(a, Log Double)], Double)
runWithEffectiveSampleSize = fmap (id &&& (effectiveSampleSizeKish . map (exp . ln . snd))) . runPopulation
where
effectiveSampleSizeKish :: [Double] -> Double
effectiveSampleSizeKish weights = square (sum weights) / sum (square <$> weights)
square :: Double -> Double
square x = x * x

measureESS :: Monad m => MSF (Population m) a b -> MSF (Population m) a (b, Double)
measureESS = morphGS $ fmap $ \pop -> fromWeightedList $ do
(particles, ess) <- runWithEffectiveSampleSize pop
pure $ map (first (first (, ess))) particles

withESS :: Monad m => Double -> MSF (Population m) (a, Double) b -> MSF (Population m) a b
withESS initESS = feedback initESS . measureESS
45 changes: 44 additions & 1 deletion rhine-bayes/src/FRP/Rhine/Bayes.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
module FRP.Rhine.Bayes where
module FRP.Rhine.Bayes
(module FRP.Rhine.Bayes, module X) where

-- log-domain
import Numeric.Log hiding (sum)
Expand All @@ -13,8 +14,12 @@ import qualified Control.Monad.Trans.MSF.Reader as DunaiReader
-- dunai-bayes
import qualified Data.MonadicStreamFunction.Bayes as DunaiBayes

import Data.MonadicStreamFunction.Bayes as X (onlyBelowEffectiveSampleSize)

-- rhine
import FRP.Rhine
import qualified Control.Monad.Bayes.Traced.Static as Static
import qualified Control.Monad.Bayes.Traced.Dynamic as Dynamic

-- * Inference methods

Expand All @@ -32,6 +37,39 @@ runPopulationCl :: forall m cl a b . Monad m =>
-> ClSF m cl a [(b, Log Double)]
runPopulationCl nParticles resampler = DunaiReader.readerS . DunaiBayes.runPopulationS nParticles resampler . DunaiReader.runReaderS

-- | Run the Resample Move Sequential Monte Carlo algorithm continuously on a 'ClSF'.
resampleMoveSequentialMonteCarloCl :: forall m cl a b . (MonadDistribution m) =>
-- resampleMoveSequentialMonteCarloCl :: forall t m cl a b . (MonadDistribution m, HasTraced t, MonadTrans t) =>
-- | Number of particles
Int ->
-- | Number of MC steps
Int ->
-- | Resampler (see 'Control.Monad.Bayes.Population' for some standard choices)
(forall x . Population m x -> Population m x)
-- | A signal function modelling the stochastic process on which to perform inference.
-- @a@ represents observations upon which the model should condition, using e.g. 'score'.
-- It can also additionally contain hyperparameters.
-- @b@ is the type of estimated current state.
-> ClSF (Static.Traced (Population m)) cl a b
-> ClSF m cl a [(b, Log Double)]
resampleMoveSequentialMonteCarloCl nParticles nMC resampler = DunaiReader.readerS . DunaiBayes.resampleMoveSequentialMonteCarlo nParticles nMC resampler . DunaiReader.runReaderS

resampleMoveSequentialMonteCarloDynCl :: forall m cl a b . (MonadDistribution m) =>
-- resampleMoveSequentialMonteCarloCl :: forall t m cl a b . (MonadDistribution m, HasTraced t, MonadTrans t) =>
-- | Number of particles
Int ->
-- | Number of MC steps
Int ->
-- | Resampler (see 'Control.Monad.Bayes.Population' for some standard choices)
(forall x . Population m x -> Population m x)
-- | A signal function modelling the stochastic process on which to perform inference.
-- @a@ represents observations upon which the model should condition, using e.g. 'score'.
-- It can also additionally contain hyperparameters.
-- @b@ is the type of estimated current state.
-> ClSF (Dynamic.Traced (Population m)) cl a b
-> ClSF m cl a [(b, Log Double)]
resampleMoveSequentialMonteCarloDynCl nParticles nMC resampler = DunaiReader.readerS . DunaiBayes.resampleMoveSequentialMonteCarloDynamic nParticles nMC resampler . DunaiReader.runReaderS

-- * Short standard library of stochastic processes

-- | White noise, that is, an independent normal distribution at every time step.
Expand Down Expand Up @@ -83,3 +121,8 @@ wienerVaryingLogDomain ::
(MonadDistribution m, Diff td ~ Double) =>
BehaviourF m td (Diff td) (Log Double)
wienerVaryingLogDomain = wienerVarying >>> arr Exp

withESS :: Monad m => Double -> ClSF (Population m) cl (a, Double) b -> ClSF (Population m) cl a b
withESS initESS = DunaiReader.readerS . DunaiBayes.withESS initESS . (arr assoc >>>) . DunaiReader.runReaderS
where
assoc ((ti, a), td) = (ti, (a, td))