This paper is available on arxiv under CC 4.0 license.
Authors:
(1) Sam Bowyer, Equal contribution, Department of Mathematics and [email protected];
(2) Thomas Heap, Equal contribution, Department of Computer Science University of Bristol and [email protected];
(3) Laurence Aitchison, Department of Computer Science University of Bristol and [email protected].
Importance sampling is a popular technique in Bayesian inference: by reweighting samples drawn from a proposal distribution, we can obtain samples and moment estimates from a Bayesian posterior over some n latent variables. Recent work, however, indicates that importance sampling scales poorly: to accurately approximate the true posterior, the required number of importance samples grows exponentially in the number of latent variables (Chatterjee and Diaconis 2018). Massively parallel importance sampling works around this issue by drawing K samples for each of the n latent variables and reasoning about all K^n combinations of latent samples. In principle, we can reason efficiently over K^n combinations of samples by exploiting conditional independencies in the generative model. However, in practice this requires complex algorithms that traverse backwards through the graphical model, and we need separate backward traversals for each computation (posterior expectations, marginals and samples). Our contribution is to exploit the source term trick from physics to entirely avoid the need to hand-write backward traversals. Instead, we demonstrate how to simply and easily compute all the required quantities (posterior expectations, marginals and samples) by differentiating through a slightly modified marginal likelihood estimator.
Importance weighting allows us to reweight samples drawn from a proposal in order to compute expectations of a different distribution, such as a Bayesian posterior. However, importance weighting breaks down in larger models. Chatterjee and Diaconis (2018) showed that the number of samples required to accurately approximate the true posterior scales as exp(D_KL(P(z|x) || Q(z))), where P(z|x) is the true posterior over latent variables, z, given data x, and Q(z) is the proposal. Problematically, we expect the KL divergence to scale with n, the number of latent variables. Indeed, if z is composed of n latent variables, and P(z|x) and Q(z) are IID over those n latent variables, then the KL divergence is exactly proportional to n. Thus, we expect the required number of importance samples to be exponential in the number of latent variables, and hence we expect accurate importance sampling to be intractable in larger models.
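The scaling argument above can be illustrated with a small sketch. Everything in it is an assumption made for illustration rather than something taken from the paper: the target is taken to be P = N(1, 1)^n, the proposal Q = N(0, 1)^n, and the function names are hypothetical. For this pair the KL divergence is n/2, and the effective sample size of the importance weights collapses as n grows.

```python
import math
import random


def kl_iid_gaussian(n, shift=1.0):
    """KL(P || Q) for P = N(shift, 1)^n and Q = N(0, 1)^n, which is n * shift^2 / 2."""
    return n * 0.5 * shift ** 2


def effective_sample_size(n, num_samples, shift=1.0, seed=0):
    """Effective sample size (sum w)^2 / sum w^2 of importance weights P(z) / Q(z)."""
    rng = random.Random(seed)
    log_weights = []
    for _ in range(num_samples):
        z = [rng.gauss(0.0, 1.0) for _ in range(n)]  # one joint sample from Q
        # log P(z) - log Q(z) for unit-variance Gaussians, summed over dimensions
        log_weights.append(sum(shift * zi - 0.5 * shift ** 2 for zi in z))
    # Subtract the maximum before exponentiating for numerical stability;
    # the effective sample size is unchanged by rescaling the weights.
    max_log_weight = max(log_weights)
    weights = [math.exp(lw - max_log_weight) for lw in log_weights]
    return sum(weights) ** 2 / sum(w ** 2 for w in weights)
```

In this toy setting `kl_iid_gaussian(10)` is 5.0, so the sample requirement scales like exp(5), roughly 148, and comparing `effective_sample_size(1, 2000)` with `effective_sample_size(20, 2000)` shows how few samples carry meaningful weight once n is moderately large.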
To resolve this issue we use a massively parallel importance sampling scheme that in effect uses an exponential number of samples to compute posterior expectations, marginals and samples (Kuntz, Crucinio, and Johansen 2023; Heap and Laurence 2023). This involves drawing K samples of each of the n latent variables from the proposal, then individually reweighting all K^n combinations of all samples of all latent variables. While reasoning about all K^n combinations of samples might seem intractable, we should in principle be able to perform efficient computations by exploiting conditional independencies in the underlying graphical model.
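As a minimal sketch of why conditional independencies help, consider a hypothetical chain model (an assumption for illustration, not a model from the paper): z_1 ~ N(0, 1), z_i | z_{i-1} ~ N(z_{i-1}, 1), x | z_n ~ N(z_n, 1), with an independent proposal N(0, 2^2) for each latent. The marginal likelihood estimate averaged over all K^n combinations can be computed either by explicit enumeration or by a forward pass along the chain that costs on the order of n K^2 operations. All function names below are made up for this example.

```python
import itertools
import math
import random


def normal_pdf(value, mean, std):
    return math.exp(-0.5 * ((value - mean) / std) ** 2) / (std * math.sqrt(2.0 * math.pi))


def draw_samples(n, K, proposal_std=2.0, seed=0):
    """Draw K proposal samples for each of the n latent variables."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, proposal_std) for _ in range(K)] for _ in range(n)]


def combination_weight(z, x, proposal_std=2.0):
    """P(x, z) / Q(z) for a single combination z = (z_1, ..., z_n)."""
    w = normal_pdf(z[0], 0.0, 1.0) / normal_pdf(z[0], 0.0, proposal_std)
    for prev, cur in zip(z, z[1:]):
        w *= normal_pdf(cur, prev, 1.0) / normal_pdf(cur, 0.0, proposal_std)
    return w * normal_pdf(x, z[-1], 1.0)


def brute_force_estimate(samples, x, proposal_std=2.0):
    """Average the weight over all K^n combinations explicitly."""
    n, K = len(samples), len(samples[0])
    total = sum(combination_weight(z, x, proposal_std)
                for z in itertools.product(*samples))
    return total / K ** n


def chain_estimate(samples, x, proposal_std=2.0):
    """Same average, computed by summing out one latent variable at a time."""
    K = len(samples[0])
    # msg[k] holds the average weight of all partial combinations ending in sample k.
    msg = [normal_pdf(z, 0.0, 1.0) / normal_pdf(z, 0.0, proposal_std)
           for z in samples[0]]
    for prev_samples, cur_samples in zip(samples, samples[1:]):
        msg = [
            sum(m * normal_pdf(cur, prev, 1.0) for m, prev in zip(msg, prev_samples))
            / (K * normal_pdf(cur, 0.0, proposal_std))
            for cur in cur_samples
        ]
    return sum(m * normal_pdf(x, z, 1.0) for m, z in zip(msg, samples[-1])) / K
```

For small n and K the two functions agree up to floating-point error, but only the chain version remains feasible as n grows, because it never materialises the K^n combinations.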
However, many computations that are possible in principle are extremely complex in practice, and that turns out to be the case here. We noticed that we could perhaps perform this reasoning over all K^n combinations of samples using methods from the discrete graphical model literature. This turned out to be less helpful than we had hoped because these algorithms involve highly complex backward traversals of the generative model. Worse, different traversals are needed for computing posterior expectations, marginals and samples, making a general implementation challenging. Our contribution is to develop a much simpler approach to computing posterior expectations, marginals and samples, which entirely avoids the need to explicitly write backward computations. Specifically, we show that posterior expectations, marginals and samples can be obtained simply by differentiating through (a slightly modified) forward computation that produces an estimate of the marginal likelihood. The required gradients can be computed straightforwardly using modern autodiff, and the resulting implicit backward computations automatically inherit potentially complex optimizations from the forward pass.
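The idea of differentiating a modified forward computation can be sketched on a toy problem. The setup below is an assumption for illustration and not the paper's implementation: a hypothetical chain model (z_1 ~ N(0, 1), z_i | z_{i-1} ~ N(z_{i-1}, 1), x | z_n ~ N(z_n, 1)) with an independent N(0, 2^2) proposal per latent. A source term exp(J * z_1) is multiplied into the forward pass, and the derivative of the log estimate with respect to J at J = 0 gives the weighted posterior mean of z_1 over all K^n combinations. To keep the sketch dependency-free, a central finite difference stands in for autodiff, and an explicit enumeration is included only as a check.

```python
import itertools
import math


def normal_pdf(value, mean, std):
    return math.exp(-0.5 * ((value - mean) / std) ** 2) / (std * math.sqrt(2.0 * math.pi))


def log_estimate(samples, x, source=0.0, proposal_std=2.0):
    """Forward pass: log marginal likelihood estimate with a source term on z_1."""
    K = len(samples[0])
    # The only modification to the forward pass is the factor exp(source * z).
    msg = [math.exp(source * z) * normal_pdf(z, 0.0, 1.0) / normal_pdf(z, 0.0, proposal_std)
           for z in samples[0]]
    for prev_samples, cur_samples in zip(samples, samples[1:]):
        msg = [
            sum(m * normal_pdf(cur, prev, 1.0) for m, prev in zip(msg, prev_samples))
            / (K * normal_pdf(cur, 0.0, proposal_std))
            for cur in cur_samples
        ]
    return math.log(sum(m * normal_pdf(x, z, 1.0) for m, z in zip(msg, samples[-1])) / K)


def expectation_by_differentiation(samples, x, eps=1e-4):
    """d/dJ of the log estimate at J = 0 (finite difference standing in for autodiff)."""
    return (log_estimate(samples, x, eps) - log_estimate(samples, x, -eps)) / (2.0 * eps)


def expectation_by_enumeration(samples, x, proposal_std=2.0):
    """Reference value: weighted mean of z_1 over all K^n combinations."""
    numerator = denominator = 0.0
    for z in itertools.product(*samples):
        w = normal_pdf(z[0], 0.0, 1.0) / normal_pdf(z[0], 0.0, proposal_std)
        for prev, cur in zip(z, z[1:]):
            w *= normal_pdf(cur, prev, 1.0) / normal_pdf(cur, 0.0, proposal_std)
        w *= normal_pdf(x, z[-1], 1.0)
        numerator += w * z[0]
        denominator += w
    return numerator / denominator
```

Note that the source term sits on the first latent variable, whose posterior would normally require propagating information backwards from x. Here no backward traversal is written by hand: the derivative of the forward computation supplies it implicitly. In a real implementation an autodiff framework would replace the finite difference.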