Variational inference, the art of approximate sampling

In the spirit of looking at fancy-sounding topics, this post is about variational inference. Suppose you granted me one superpower and I chose the ability to sample from any distribution quickly and accurately. Now, you might think that’s a crappy superpower, but it would basically let me fit any model I want and provide uncertainty estimates.

To make the problem concrete, let’s suppose you are trying to sample from a distribution \(p(x)\). To make matters worse, you don’t know \(p(x)\) itself, only \(f(x) = Cp(x)\) for some unknown constant \(C>0\). That seems a little contrived, right? But let’s say I use the following Bayesian linear model (see here for an introduction): \[ \begin{aligned} y | \beta &\sim \text{Laplace}(X\beta, 1)\\ \beta &\sim N(0,1). \end{aligned} \] I can calculate the posterior up to constants quite easily by Bayes’ rule: \[ p(\beta | y, X) \propto p(y | X, \beta)p(\beta) = \left(\tfrac{1}{2}\right)^n e^{-\sum_{i=1}^n|y_i - (X\beta)_i|}\,(2\pi)^{-d/2}e^{-\frac{1}{2}\beta^T\beta}, \] where \(n\) is the number of observations and \(d\) is the dimension of \(\beta\). Looks easy enough, but recovering the normalising constant means integrating this function over \(\beta\), and doing that with the trapezium rule in \(100\) dimensions is pretty difficult.
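To see that evaluating \(f\) really is the easy part, here is a minimal R sketch of the unnormalised log-posterior for this model (the data and the helper name `log_f` are made up for illustration). Everything we do later only ever needs to evaluate this kind of function, never the normalising constant.

n <- 100; d <- 5
X <- matrix(rnorm(n * d), n, d)   # made-up design matrix
y <- rnorm(n)                     # made-up responses

# log f(beta) = -sum_i |y_i - (X beta)_i| - 0.5 * beta' beta, constants dropped
log_f <- function(beta) {
  -sum(abs(y - X %*% beta)) - 0.5 * sum(beta^2)
}

log_f(rep(0, d))  # cheap to evaluate at any beta; normalising it is the hard part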

Generally there are two approaches to sampling from \(p(x)\). One is Markov chain Monte Carlo, which is what packages like Stan do. Today we are instead going to look at approximating \(p(x)\), which is called variational inference. The general idea is to have a family of easy-to-sample distributions with densities \(q_\lambda(x)\), indexed by a parameter \(\lambda\), and to pick the best \(\lambda^*\) so that \(q_{\lambda^*}(x) \approx p(x)\). The natural question then is: how do we measure how well \(q_\lambda\) approximates \(p\)?

Kullback–Leibler divergence

Suppose you’ve outsourced your sampling needs to a dodgy sampling company run by me. You want some \(N(0,1)\) random variables, but unfortunately I only have Laplace\((0, 2^{-1/2})\) random variables (which have the same mean and variance as the standard normal). I send you these anyway and tell you they are normally distributed, in an attempt to fool you.

Our national sport, fooling people

After a while you notice that my samples have a lot of outliers. That’s because the Laplace distribution has heavier tails. So you plot out what you got from me against a normal density and see this:

library(rmutil)     # for rlaplace() and dlaplace()
library(tidyverse)  # for tibble, %>%, purrr and ggplot2
tibble(received = rlaplace(n=100, s=1/sqrt(2))) %>% 
  ggplot(aes(received)) +
  geom_histogram(aes(y=..density..), bins = 20) +
  stat_function(fun = dnorm, args=list(mean=0, sd=1), colour='red') +
  labs(x='', y='Density')

You notice that it looks awfully like I’m sending you Laplace distribution samples, but before you go to your boss, you want some hard statistical proof of this trickery. So you do a likelihood ratio test of \(H_0\): the samples come from a \(N(0,1)\) distribution, against \(H_1\): the samples come from a Laplace\((0, 2^{-1/2})\) distribution. This is a classical hypothesis test: take \[ L = \log\left(\prod_{i=1}^n \frac{q(x_i)}{p(x_i)}\right) = \sum_{i=1}^n \log\left(\frac{q(x_i)}{p(x_i)}\right), \] where \(p\) is the normal density and \(q\) is the Laplace density, and reject the null hypothesis if \(L > c\) for some threshold \(c\). So with my dodgy samples, the value of \(L\) you get is

samples <- rlaplace(n=100, s=1/sqrt(2))
L <-  prod(dlaplace(samples, 0, 1/sqrt(2)) / dnorm(samples, 0, 1)) %>% log
print(L)
#> [1] 10.33139

So it seems pretty likely that I’m tricking you and you go to your boss and complain.

Notice that this test isn’t symmetric. Suppose the roles of the normal and Laplace distributions were reversed, i.e. I am giving you normal samples whereas you wanted Laplace. What would the test say then?

samples <- rnorm(n=100)
L <-  prod(dnorm(samples, 0, 1) / dlaplace(samples, 0, 1/sqrt(2))) %>% log
print(L)
#> [1] 0.5965472

Well, now the number we get is much lower than before. This makes sense when you think about it. If I give you Laplace samples and you expect normal ones, there are going to be some samples far out in the tails that have very low probability under the normal. When the roles are reversed and I give you normal samples when you expect Laplace, you expect more samples in the tails and they simply don’t show up; the latter event has a much higher probability than the former.

So your boss calls me up, tells me about your test and complains about the samples you received (I gave you Laplace when you wanted normal). I grudgingly admit that I have been tricking you, but promise to resolve the issue.

No more tricks please

Unfortunately, normal samplers are hard to come by and expensive, so I go back to my old tricks, but now I know I’m being watched. I go out shopping for other, cheaper distribution samplers. Since I know you are doing a likelihood ratio test, I also know that if I send you samples from a distribution with density \(q\), your test statistic per sample, \(L/n\), will converge (by the law of large numbers) to \[ \text{KL}(q||p) = \int q(x) \log\left(\frac{q(x)}{p(x)}\right) \, dx, \] which is the Kullback–Leibler divergence. So while shopping, what I should try to do is minimise the Kullback–Leibler divergence between what I buy and what you asked for.
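To connect this back to the two test statistics above, here is a quick numerical check (a sketch; the `kl` helper is my own name, and it assumes rmutil is loaded as before). It evaluates both directions of the KL divergence by numerical integration, and the two numbers come out different, which is exactly the asymmetry we saw in the two tests.

# KL(q || p) = E_q[ log(q(X)/p(X)) ], integrated over a wide finite range
# to avoid log(0) issues far out in the tails
kl <- function(dq, dp, lower = -30, upper = 30) {
  integrate(function(x) dq(x) * (log(dq(x)) - log(dp(x))), lower, upper)$value
}
dlap <- function(x) dlaplace(x, m = 0, s = 1/sqrt(2))
kl(dlap, dnorm)   # Laplace samples when you expected normal
kl(dnorm, dlap)   # normal samples when you expected Laplace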

ELBO: Evidence lower bound

Back to variational inference: we would now like to select \[ \lambda^* = \text{argmin}_\lambda \text{KL}(q_\lambda||p), \] but we only know \(p(x)\) up to a multiplicative constant, \(f(x) = Cp(x)\). Let’s start by writing out the KL divergence: \[ \begin{aligned} \text{KL}(q_\lambda||p) &= \int q_\lambda(x) \log\left(\frac{q_\lambda(x)}{p(x)}\right) \, dx\\ &=\int q_\lambda(x) \log q_\lambda(x) \, dx - \int q_\lambda(x) \log p(x) \, dx \\ &= \int q_\lambda(x) \log q_\lambda(x) \, dx - \int q_\lambda(x) \log f(x) \, dx + \log C. \end{aligned} \] Since \(\log C\) does not depend on \(\lambda\), minimising the KL divergence is equivalent to maximising the evidence lower bound (ELBO), \[ \text{ELBO}(q_\lambda || f) = \int q_\lambda(x) \log f(x) \, dx - \int q_\lambda(x) \log q_\lambda(x) \, dx, \] or, equivalently, minimising its negative. Why is this easier? Remember that we can easily sample from the distribution \(X^{(\lambda)}\) whose density is \(q_\lambda\). So an easy way of handling the integrals above is to draw i.i.d. samples \(x^{(\lambda)}_1, ..., x^{(\lambda)}_n\) of \(X^{(\lambda)}\) and form the Monte Carlo estimate of the negative ELBO, \[ \text{ELBO-approx}(q_\lambda || f) = \frac{1}{n}\sum_{i=1}^n \log q_\lambda(x_i^{(\lambda)}) - \frac{1}{n}\sum_{i=1}^n \log f(x_i^{(\lambda)}). \] We can then numerically minimise ELBO-approx over \(\lambda\) to find the optimal \(\lambda^*\).

Here comes the ELBO

So let’s go back to the earlier example and do it in one dimension: \[ f(\beta) = e^{-\sum_{i=1}^n|y_i - x_i\beta| -\frac{1}{2}\beta^2}, \] where the \(x_i\) and \(y_i\) are given. We will simply generate these randomly, because, why not. As our approximating family, let’s take \(N(\mu, \sigma^2)\), so \(\lambda = (\mu, \sigma)\). Using the normal distribution has the added benefit that we can simulate one set of \(N(0,1)\) variables \(Z_1\), …, \(Z_n\) and then scale them as \(\mu + \sigma Z_i\) to get \(N(\mu, \sigma^2)\) variables.

y <- rnorm(100)
x <- rnorm(100)
# unnormalised posterior density; vectorised so it can be evaluated at each sample
f <- function(beta) {
  sapply(beta, function(b) exp(-sum(abs(y - x * b)) - .5 * b^2))
}
z <- rnorm(500)
elbo_approx <- function(mu, sigma, z) {
  if (sigma <= 0) return(Inf)
  mean(dnorm(mu + sigma * z, mu, sigma, log = TRUE) - log(f(mu + sigma * z)))
}
lambda_min <- optim(c(0, 1), function(x) elbo_approx(x[1], x[2], z))
print(lambda_min)
#> $par
#> [1] -0.09681142  0.05594696
#> 
#> $value
#> [1] 378.2323
#> 
#> $counts
#> function gradient 
#>       83       NA 
#> 
#> $convergence
#> [1] 0
#> 
#> $message
#> NULL

Note that the minimised objective is not the KL divergence itself (it differs from it by the unknown constant \(\log C\)), so the number alone does not tell us how well the normal distribution approximates the posterior. Since this is one dimensional, we can instead integrate \(f\) numerically to normalise it and see visually how good the fit is.

int_f <- integrate(function(x) map_dbl(x, f), -Inf, Inf)
p <- function(beta) map_dbl(beta, f) / int_f$value
ggplot(tibble(x = seq(-1, 1, .1)), aes(x)) +
  stat_function(fun = p, aes(colour='Given distribution')) +
  stat_function(fun = dnorm, args = list(mean=lambda_min$par[1], sd=lambda_min$par[2]), aes(colour='Approximation')) +
  labs(colour='',
       title='Variational inference using normal distribution',
       x='', y='Density')

Some tools

If you want to do variational inference seriously, there are several tools available: Stan, PyMC3 and Edward, to name a few. The latter two are Python projects which use Theano and TensorFlow as back-ends respectively. Here is an example of the above problem in Edward.

from edward.models import *
import edward as ed
import tensorflow as tf
import numpy as np
import pandas as pd
# Make up some data
n, p = 100, 5
beta_true = np.zeros(p)                                   # true coefficients, all zero
X = np.random.normal(size=[n, p]).astype(np.float32)
y_obs = np.matmul(X, beta_true) + np.random.normal(size=n)
# Edward model
beta = Normal(tf.zeros(p), tf.ones(p))
y = Laplace(ed.dot(X, beta), tf.ones(n))
qbeta = Normal(tf.Variable(tf.zeros(p)), tf.Variable(tf.ones(p)))
inference = ed.KLqp({beta: qbeta}, data={y: y_obs})
inference.run()
#> 
1000/1000 [100%] ██████████████████████████████ Elapsed: 1s | Loss: 156.488
sess = ed.get_session()
print('Mean: %s' % sess.run(qbeta.mean()))
#> Mean: [ 0.17253913  0.1097104   0.08593578 -0.01670811 -0.01817157]
print('Sdev: %s' % sess.run(qbeta.stddev()))
#> Sdev: [0.05179016 0.063154   0.14621048 0.08485013 0.12385876]
