A Brief, High-Level Intro to Amortized VI

In this post I will give a very high-level introduction to the concept of amortized variational inference [1]. Before diving in, let me briefly describe the setting, as the best way to understand amortized variational inference (in my opinion) is in the context of regular variational inference (VI).

Quick background on VI

Let’s assume that we have some latent variable model, such that for every data-point $x$ there is a corresponding local latent variable $z$. VI is applicable more generally than this setting, but this scenario provides a good intuition for how and why VI works. We are interested in the posterior distribution of the latent variable conditioned on the observations, $p(z | x)$, which is crucial to “learning” in such a model. For simple models, we can apply Bayes’ rule directly to get this distribution. For more complicated models, this is unfortunately intractable.

The basic premise of VI is to approximate this intractable distribution with a simpler one that we know how to evaluate and handle. We refer to this simpler distribution as the approximate posterior, or variational approximation, and denote it $q$. Typically, this will be from some parametrized family $\mathcal{Q}$, e.g., Gaussian. We can then optimize the parameters of this simple distribution (sometimes referred to as variational parameters) w.r.t. an appropriate objective function, the evidence lower bound (ELBO), which is a lower bound on the log-marginal likelihood of the observed data. This is equivalent to finding the distribution in the specified family closest (in the KL-divergence sense) to the true posterior [2]. The beauty of this is that we have substituted intractable posterior inference with optimization, which is something we can generally do at scale (especially with the introduction of stochastic VI [3]).
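For concreteness, the objective and the KL statement are tied together by a standard identity (writing $\phi$ for the variational parameters of $q$):

$$
\log p(x) = \underbrace{\mathbb{E}_{q_{\phi}(z)}\left[\log p(x, z) - \log q_{\phi}(z)\right]}_{\text{ELBO}(\phi)} + \mathrm{KL}\left(q_{\phi}(z) \,\|\, p(z \mid x)\right).
$$

Since $\log p(x)$ does not depend on $\phi$, maximizing the ELBO over $\phi$ is exactly minimizing the KL divergence from $q_{\phi}$ to the true posterior.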

Practical issues with VI

What does this mean, practically? For a Gaussian posterior approximation, we would introduce a mean and a variance parameter for every observation, and optimize all of these jointly. There are two problems to notice with this procedure. The first is that the number of parameters we need to optimize grows (at least) linearly with the number of observations. Not ideal for massive data-sets. The second is that if we get new observations, or have test observations we would like to perform inference for, it is not clear how this fits into our framework. In general, for new observations we would need to re-run the optimization procedure.
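To see the first problem in code, here is a minimal sketch of free-form VI, using a toy model chosen so that the reparameterized gradients can be written by hand (the model, learning rate, and iteration count are all illustrative assumptions). The key point is that the variational parameter arrays grow with the number of observations:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy latent-variable model (an assumption for this sketch):
#   z_i ~ N(0, 1),   x_i | z_i ~ N(z_i, 1)
N = 1000
z_true = rng.normal(size=N)
x = z_true + rng.normal(size=N)

# Free-form (non-amortized) Gaussian approximation q(z_i) = N(mu_i, std_i^2):
# one mean and one log-std per observation, i.e. 2N variational parameters.
mu = np.zeros(N)
log_std = np.zeros(N)
lr = 0.05

for step in range(2000):
    std = np.exp(log_std)
    eps = rng.normal(size=N)
    z = mu + std * eps                            # reparameterized sample from q
    # One-sample ELBO gradients, derived by hand for this model:
    #   log p(x, z) = -0.5 z^2 - 0.5 (x - z)^2 + const
    #   log q(z)    = -log_std - 0.5 eps^2 + const  (at z = mu + std * eps)
    dmu = x - 2.0 * z                             # d log p(x, z) / dz, chained through z
    dlog_std = (x - 2.0 * z) * std * eps + 1.0    # the +1 comes from the entropy term
    mu += lr * dmu
    log_std += lr * dlog_std

# The exact posterior here is N(x_i / 2, 1/2), so each mu_i should sit near x_i / 2
# (up to Monte Carlo noise from the one-sample gradient estimates).
print(np.abs(mu - x / 2).mean())
```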

Amortizing Posterior Inference

Amortized VI is the idea that instead of optimizing a set of free parameters, we can introduce a parameterized function that maps from observation space to the parameters of the approximate posterior distribution. In practice, we might (for example) introduce a neural network that accepts an observation as input and outputs the mean and variance parameters for the latent variable associated with that observation [4][5]. We can then optimize the parameters of this neural network instead of the individual parameters of each observation. That’s it — that’s what amortized VI is.
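Continuing the toy sketch above, here is the amortized version. To keep the gradients hand-written, the amortization function here is just an affine map rather than a full neural network; the idea is identical, only the function class changes:

```python
import numpy as np

rng = np.random.default_rng(0)

# Same toy model as above:  z_i ~ N(0, 1),  x_i | z_i ~ N(z_i, 1)
N = 1000
x = rng.normal(size=N) + rng.normal(size=N)

# Amortized approximation: a single parameterized function x -> (mu, log_std).
# Here that function is an affine map with 4 parameters in total, independent
# of N; in practice it would be a neural network.
a, b, c, d = 0.0, 0.0, 0.0, 0.0
lr = 0.05

for step in range(2000):
    mu = a * x + b                     # "encoder" output: posterior mean per point
    log_std = c * x + d                # "encoder" output: posterior log-std per point
    std = np.exp(log_std)
    eps = rng.normal(size=N)
    z = mu + std * eps                 # reparameterized sample from q(z | x)

    # Per-point ELBO gradients w.r.t. mu and log_std (same algebra as before),
    # then the chain rule into the four shared parameters.
    dmu = x - 2.0 * z
    dlog_std = (x - 2.0 * z) * std * eps + 1.0
    a += lr * np.mean(dmu * x)
    b += lr * np.mean(dmu)
    c += lr * np.mean(dlog_std * x)
    d += lr * np.mean(dlog_std)

# The exact posterior N(x_i / 2, 1/2) happens to be representable by this map:
# a -> 0.5, b -> 0, c -> 0, d -> 0.5 * log(0.5) (roughly -0.35).
print(a, b, c, d)
```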

What’s great about this?

Notice that this directly addresses the issues mentioned earlier. First, the number of variational parameters is now constant w.r.t. the data size! In our example, we need only specify the parameters of the neural network, and that is not dependent in any way on the number of observations we have. Second, for a new observation, all we need to do is pass it through the network, and voilà, we have an approximate posterior distribution over its associated latent variables, at the constant cost of a forward pass through the network! These gains are the source of the term amortized: the cost of fitting the inference network is shared (amortized) across all observations and any future inference queries.
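Continuing the sketch above (the trained affine map standing in for a network), posterior inference for a hypothetical new observation is a single evaluation of that function:

```python
# Approximate posterior for a previously unseen observation: no re-optimization,
# just one pass through the (trained) amortization function from the sketch above.
x_new = 1.3                              # hypothetical new data point
mu_new = a * x_new + b
std_new = np.exp(c * x_new + d)
# q(z_new | x_new) is N(mu_new, std_new^2); the exact posterior here is N(0.65, 0.5).
```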

Is this a free win?

No, of course not — that’s almost always a rhetorical question. A common misconception is that since this distribution is generated by a neural network, it is somehow more expressive than the original framing. Quite the opposite: the amortized posterior approximation is less expressive than its free-form counterpart. The approximate posterior is still Gaussian, but now there is an additional constraint imposed by requiring that all the variational parameters lie in the range of the network. Within a specific parametric family, nothing is more general than freely optimizing the variational parameters. This cost is known as the amortization gap [6]. For a network with infinite capacity, this gap vanishes, and we can do (only) as well as in the free-form optimization setting. Of course, that is not the case in any practical implementation, as all networks have finite capacity.
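To state the gap a little more precisely, write $\mathcal{L}(x; \lambda)$ for the ELBO of an observation $x$ under variational parameters $\lambda$, and $f_{\theta}$ for the amortization network (notation introduced just for this equation). The amortization gap for $x$ is then

$$
\max_{\lambda} \mathcal{L}(x; \lambda) \;-\; \mathcal{L}\big(x; f_{\theta}(x)\big) \;\ge\; 0,
$$

i.e., how much worse the network’s output is than the best variational parameters in the family for that particular observation.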

Regardless of this drawback, amortized VI represents a significant advancement in probabilistic machine learning, and has opened the door to massive improvements in both modelling and inference.

Originally a Quora answer.

References



  1. Zhang, Cheng, et al. Advances in Variational Inference. 2018.

  2. Wainwright, Martin and Jordan, Michael. Graphical Models, Exponential Families, and Variational Inference. 2008.

  3. Hoffman, Matthew, et al. Stochastic Variational Inference. 2013.

  4. Kingma, Diederik and Welling, Max. Auto-Encoding Variational Bayes. 2013.

  5. Rezende, Danilo, et al. Stochastic Backpropagation and Approximate Inference in Deep Generative Models. 2014.

  6. Cremer, Chris, et al. Inference Suboptimality in Variational Autoencoders. 2018.
