This guide shows, in a series of vignettes of monotonically increasing complexity, how to use the Pyro probabilistic programming language.

Step 1: Introducing our first random variable, using the `pyro.sample` primitive.

Step 2: Introducing the concept of observed values, via the `obs=` argument to `pyro.sample`.

Step 3: Introducing `pyro.plate`, a primitive for marking conditionally independent variables.

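As a taste of what these first three vignettes cover, here is a minimal sketch (the model, its variable names, and the data are hypothetical, not the ones used in the vignettes):

```python
import torch
import pyro
import pyro.distributions as dist

def model(data):
    # Step 1: a latent random variable, drawn with the pyro.sample primitive.
    mu = pyro.sample("mu", dist.Normal(0.0, 1.0))
    # Step 3: a plate marks the data points as conditionally independent given mu.
    with pyro.plate("data", len(data)):
        # Step 2: the obs= argument pins this sample site to observed values.
        pyro.sample("obs", dist.Normal(mu, 1.0), obs=data)

data = torch.tensor([0.2, -0.5, 1.1])
model(data)
```
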
Step 4: Introducing `poutine.trace`, our first "effect handler".

Step 5: Introducing the internal structure of `poutine.trace`, and how that structure can be used to calculate the log probability sum and also to create a graphical model plate diagram.

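Continuing the hypothetical sketch above, tracing a model looks roughly like this:

```python
import pyro.poutine as poutine

# Run the model under the trace effect handler, recording every sample site.
trace = poutine.trace(model).get_trace(data)

# The trace stores each site as a dict of metadata in trace.nodes.
for name, site in trace.nodes.items():
    if site["type"] == "sample":
        print(name, site["value"].shape)

# The summed log probability of all sample sites, latent and observed.
print(trace.log_prob_sum())
```
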
Step 6: Combining observed values with calculating the log probability sum of our distribution.

Step 7: Demonstrating two separate random variables in one model, and how we can sample the individual values of each using `poutine.trace`.

Step 8: Using `poutine.condition` to condition on the value of one random variable before we sample from our posterior, and how this affects the log probability sum.

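A rough sketch of the idea, again on the hypothetical model from above:

```python
# Pin the latent mu to a fixed value before tracing.
conditioned_model = poutine.condition(model, data={"mu": torch.tensor(0.0)})
trace = poutine.trace(conditioned_model).get_trace(data)
print(trace.log_prob_sum())  # the log joint evaluated with mu fixed at 0
```
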
Step 9: Demonstrating how to look at the gradient at a given random variable in our model.

Step 10: Using this gradient to write a simple gradient descent algorithm to find the MAP estimate of a random variable in our model.

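Steps 9 and 10 fit together naturally; here is a minimal sketch of hand-rolled MAP estimation on the hypothetical model (the learning rate and step count are arbitrary):

```python
mu = torch.tensor(0.0, requires_grad=True)

for step in range(100):
    # The log joint, with the latent site pinned to the current parameter value.
    conditioned_model = poutine.condition(model, data={"mu": mu})
    loss = -poutine.trace(conditioned_model).get_trace(data).log_prob_sum()
    loss.backward()  # populates mu.grad, the gradient at the random variable
    with torch.no_grad():
        mu -= 0.1 * mu.grad  # plain gradient descent step
        mu.grad.zero_()

print(mu)  # approaches the MAP estimate
```
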
Step 11: Showing how the gradient descent algorithm laid out in step 10 fails if we have much more complicated, deeply nested hierarchical priors.

Step 12: Since we cannot optimize our loss function, the next best thing to try is an MCMC algorithm to sample from our posterior.

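In Pyro that could look roughly like this (a sketch on the hypothetical model; the sample counts are arbitrary):

```python
from pyro.infer import MCMC, NUTS

kernel = NUTS(model)
mcmc = MCMC(kernel, num_samples=500, warmup_steps=200)
mcmc.run(data)
posterior_mu = mcmc.get_samples()["mu"]  # posterior draws of mu
```
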
Step 13: The MCMC sampler works well, but it may be too slow, so the next best thing to try is SVI. For the sake of simplicity we return to the simple model from step 10 and show how to run SVI on it using a manually constructed guide.

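A minimal sketch of SVI with a hand-written guide for the hypothetical model (the parameter names and optimizer settings are made up):

```python
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from pyro.distributions import constraints

def guide(data):
    # Variational parameters of an approximate Normal posterior over mu.
    loc = pyro.param("loc", torch.tensor(0.0))
    scale = pyro.param("scale", torch.tensor(1.0), constraint=constraints.positive)
    pyro.sample("mu", dist.Normal(loc, scale))

svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
for step in range(1000):
    svi.step(data)
```
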
Step 14: Demonstrating how to use the `AutoNormal` class to automatically construct a guide function for us; we see the results are the same as with our manually constructed guide.

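Swapping in the autoguide is a one-line change to the sketch above:

```python
from pyro.infer.autoguide import AutoNormal

auto_guide = AutoNormal(model)  # builds a mean-field Normal guide over the latents
svi = SVI(model, auto_guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
```
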
Step 15: Demonstrating nested plates, and the peculiarity of indexing into nested plates in Pyro. We show two equivalent ways to nest plates.

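The peculiarity in question is dimension allocation. A sketch of the effect, assuming auto-allocated plate dims (the shape comments reflect that assumption):

```python
def nested_model():
    with pyro.plate("rows", 10):        # entered first, allocated dim -1
        with pyro.plate("cols", 5):     # entered second, allocated dim -2
            x = pyro.sample("x", dist.Normal(0.0, 1.0))
            # x.shape == (5, 10): the inner plate indexes the leftmost dim

def nested_model_explicit():
    # Equivalent, with each plate pinned to an explicit dimension.
    with pyro.plate("rows", 10, dim=-1), pyro.plate("cols", 5, dim=-2):
        x = pyro.sample("x", dist.Normal(0.0, 1.0))  # x.shape == (5, 10)
```
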
Note: it will be very difficult to understand the next vignette unless you understand the concept of "batch" and "event" dimensions. This concept is explained very well by this blog post (you only need to read up to the "Other Scenarios" section): https://ericmjl.github.io/blog/2019/5/29/reasoning-about-shapes-and-probability-distributions/

Step 16: Demonstrating the difference between batch and event dimensions with an example model that has to deal with both.

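A tiny illustration of the distinction, using `.to_event` to reinterpret batch dimensions as event dimensions:

```python
d = dist.Normal(torch.zeros(3), torch.ones(3))
print(d.batch_shape, d.event_shape)  # batch (3,), event (): three independent scalars

d_vec = d.to_event(1)
print(d_vec.batch_shape, d_vec.event_shape)  # batch (), event (3,): one 3-vector

x = pyro.sample("x", d_vec)  # log_prob is now summed over the event dimension
```
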
Step 17: Introducing a model with a discrete latent variable, showing how we can use `TraceEnum_ELBO` to marginalize over it, and using `infer_discrete` to create a classifier from our trained model.

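A sketch of the shape of such a model (a made-up two-component mixture; the vignette's actual model may differ):

```python
from pyro.infer import TraceEnum_ELBO, config_enumerate

@config_enumerate  # enumerate the discrete sites in parallel
def mixture_model(data):
    weights = pyro.sample("weights", dist.Dirichlet(torch.ones(2)))
    locs = pyro.sample("locs", dist.Normal(torch.zeros(2), 10.0).to_event(1))
    with pyro.plate("data", len(data)):
        z = pyro.sample("z", dist.Categorical(weights))  # discrete latent
        pyro.sample("obs", dist.Normal(locs[z], 1.0), obs=data)

# TraceEnum_ELBO sums z out exactly instead of sampling it; after training,
# infer_discrete can recover the most likely z for each data point.
elbo = TraceEnum_ELBO(max_plate_nesting=1)
```
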
Step 18: Demonstrating that some local optima are better than others, and how our model from step 17 is sensitive to its initialization values.

Step 19: Demonstrating how to plot the ELBO loss curve for a model. Inspecting this curve can give you insight into whether you need to run more SVI steps and whether your model is converging.

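Since `svi.step` returns the loss estimate, the curve falls straight out of the training loop (continuing the SVI sketch above):

```python
import matplotlib.pyplot as plt

losses = [svi.step(data) for _ in range(1000)]

plt.plot(losses)
plt.xlabel("SVI step")
plt.ylabel("ELBO loss")
plt.show()
```
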
Finally, a contributor has made a completely stripped-down implementation of the Pyro framework that is only a few hundred lines of code and closely emulates what the full version of Pyro does. Reading through this implementation can give you deep insight into how Pyro itself works, and I highly recommend it for advanced users: