Sources: Matthew N. Bernstein blog, Deep Learning Book, Ian Goodfellow
The term inference usually refers to computing the probability distribution of one set of variables given another set of variables. We are usually interested in computing \(p(z|v)\), in the context of latent variable models or of multitask learning where the task is defined by a vector \(z\). The challenge lies in the difficulty of computing \(p(z|v)\), or expectations with respect to it, \(E_{z\sim p(z|v)}[\cdot]\).
Intractable inference in deep learning arises from the interactions between the activations of the many layers, which create either mutual ancestors or large cliques of activations. To address the intractable inference problem we can cast exact inference as an optimization problem and derive approximate inference algorithms by approximating that underlying optimization.
To define the optimization problem, assume that we have observed data \(v\) that is a realization of a random variable \(V\). We posit the existence of another random variable \(Z\), with \(V\) and \(Z\) distributed according to a joint distribution \(p(V,Z;\theta)\), where \(\theta\) parameterizes the distribution. The joint distribution describes how \(Z\) and \(V\) depend on each other. Suppose the means \(m_z,m_v\) of \(z\) and \(v\) are known; then we might be interested in the covariance \(\sigma_{zv}\), which measures the linear association between \(z\) and \(v\): \[\sigma_{zv}=E[(z-m_z)(v-m_v)] \tag{1}\]
Hence, to compute \(\sigma_{zv}\) it is not enough to know the probability of each \(z\) and the probability of each \(v\) separately; we need the joint probability of each pair \(z\) and \(v\).
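The point about needing the joint probabilities can be checked numerically. The sketch below is only an illustration: the support points and the two joint tables are made-up values, chosen so that both tables have identical marginals for \(z\) and \(v\) while their covariances differ.

```python
import numpy as np

# Hypothetical support points for z and v (illustrative values only).
z_vals = np.array([0.0, 1.0])
v_vals = np.array([0.0, 1.0])

def covariance(joint):
    """Covariance of (z, v) from a table with joint[i, j] = p(z_i, v_j)."""
    p_z = joint.sum(axis=1)            # marginal distribution of z
    p_v = joint.sum(axis=0)            # marginal distribution of v
    m_z = p_z @ z_vals                 # mean of z
    m_v = p_v @ v_vals                 # mean of v
    dz = z_vals - m_z
    dv = v_vals - m_v
    # sum_i sum_j p_ij (z_i - m_z)(v_j - m_v)
    return float(dz @ joint @ dv)

# Two made-up joint tables with the same marginals (0.5, 0.5) for z and v.
independent = np.array([[0.25, 0.25],
                        [0.25, 0.25]])
dependent = np.array([[0.45, 0.05],
                      [0.05, 0.45]])

cov_independent = covariance(independent)
cov_dependent = covariance(dependent)
print(cov_independent, cov_dependent)
```

Both tables give the same marginals, yet only the second has a non-zero covariance, so the marginals alone cannot determine \(\sigma_{zv}\).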
Suppose we run two separate experiments with outcomes \(z_i\) and \(v_j\), means \(m_1\) and \(m_2\), and joint probability \(p_{ij}\) for each pair. The covariance between the experiments, Equation 1, then becomes: \[ \begin{aligned} \sigma_{12}&=E[(z-m_1)(v-m_2)] \\ &=\sum_{i}\sum_{j}p_{ij}(z_i-m_1)(v_j-m_2) \end{aligned} \tag{2}\]
But \(v\) is a realization of \(V\) and not of \(Z\), and therefore \(z\) remains “latent”, i.e. not observed. We might be interested either in computing the posterior distribution \(p(Z|V;\theta)\) given a fixed \(\theta\), or in finding the maximum-likelihood estimate \(\arg\max_{\theta}l(\theta)\) when \(\theta\) is unknown, where \(l(\theta)\) is the log-likelihood function defined by: \[ l(\theta):=\log p(v;\theta)=\log\int_{z}p(v,z;\theta)\,dz \tag{3}\]
We would like to compute the log-probability of the observed variable, \(\log p(v;\theta)\), but sometimes it is too costly to marginalize out \(z\) and this computation becomes intractable. Instead we can compute the evidence lower bound (ELBO), also called the negative variational free energy, \(\mathcal{L}(v,\theta,q)\), which is a lower bound on \(\log p(v;\theta)\). The evidence, in evidence lower bound, is the log-likelihood evaluated at a fixed \(\theta\): \[ \text{evidence}:=\log p(v;\theta) \tag{4}\]
Hence, if our optimization has found a good value of \(\theta\), we would expect the marginal probability of the observed variable \(v\) to be high. The evidence thus quantifies how well the model distribution \(p_{model}\), parameterized by \(\theta\), approximates the data distribution \(p_{data}\), i.e. \(p_{model}(z|v;\theta)\approx p_{data}(z|v)\).
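As a small illustration of Equation 3 and of the evidence as a quality measure, the sketch below uses a case where the marginalization is cheap: a discrete latent \(z\) that picks one of two Gaussian components, so the integral becomes a sum over two terms. The mixture model, the parameter values, and the function name `log_evidence` are assumptions made for this example, not part of the sources.

```python
import numpy as np

def log_evidence(v, weights, means, sigma=1.0):
    """Sum over data points of log p(v; theta) = log sum_z p(z) p(v | z)."""
    v = np.asarray(v, dtype=float)[:, None]                  # shape (n, 1)
    log_prior = np.log(weights)[None, :]                     # log p(z)
    log_lik = (-0.5 * np.log(2 * np.pi * sigma**2)
               - 0.5 * ((v - means[None, :]) / sigma) ** 2)  # log p(v | z)
    log_joint = log_prior + log_lik                          # log p(v, z)
    # Numerically stable log-sum-exp over the latent z.
    top = log_joint.max(axis=1, keepdims=True)
    per_point = top[:, 0] + np.log(np.exp(log_joint - top).sum(axis=1))
    return float(per_point.sum())

# Simulated data: z is drawn but never passed to log_evidence (it stays latent).
rng = np.random.default_rng(0)
true_means = np.array([-2.0, 2.0])
z = rng.integers(0, 2, size=500)
v = rng.normal(true_means[z], 1.0)

weights = np.array([0.5, 0.5])
good = log_evidence(v, weights, np.array([-2.0, 2.0]))   # theta near the truth
bad = log_evidence(v, weights, np.array([-0.5, 0.5]))    # poorly fitted theta
print(good, bad)
```

The parameters close to the data-generating ones yield the higher evidence, which is the sense in which the evidence scores a choice of \(\theta\).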
Let \(q\) be an arbitrary probability distribution over \(Z\). The evidence lower bound is a lower bound on the evidence that makes use of \(q\). Concretely: \[ \begin{aligned} \log p(v;\theta)\ge \mathbb{E}_{z\sim q}\left[ \log \frac{p(v,z;\theta)}{q(z)} \right] \end{aligned} \tag{5}\] where the ELBO is the right-hand side of Equation 5:
\[ \begin{aligned} \text{ELBO}:= \mathbb{E}_{z\sim q}\left[ \log \frac{p(v,z;\theta)}{q(z)} \right] \end{aligned} \tag{6}\]
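One way to see why Equation 5 holds is Jensen's inequality applied to the concave logarithm. The steps below are a sketch that assumes \(q(z)>0\) wherever \(p(v,z;\theta)>0\), so that the ratio inside the expectation is well defined:

\[ \begin{aligned} \log p(v;\theta) &= \log\int_{z}p(v,z;\theta)\,dz \\ &= \log\int_{z}q(z)\,\frac{p(v,z;\theta)}{q(z)}\,dz \\ &= \log \mathbb{E}_{z\sim q}\left[ \frac{p(v,z;\theta)}{q(z)} \right] \\ &\ge \mathbb{E}_{z\sim q}\left[ \log \frac{p(v,z;\theta)}{q(z)} \right] \end{aligned} \]

The first line is the definition of the evidence from Equation 3, the second multiplies and divides by \(q(z)\), and the last step is Jensen's inequality.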
The gap between the evidence and the ELBO is the Kullback-Leibler divergence between \(q(z)\) and \(p(z|v;\theta)\): \[ \begin{aligned} D_{KL}(q(z)|| p(z|v;\theta)) \end{aligned} \tag{7}\] This is the basis of the approximation approach called variational inference, in which we learn to infer \(q\) with an optimization algorithm. As long as \(z\) is continuous, we can back-propagate through samples of \(z\) drawn from \(q(z|v)=q(z;f(v;\theta))\) to obtain a gradient with respect to \(\theta\) in order to maximize \(\mathcal{L}(v,\theta,q)\). We can write \(\mathcal{L}(v,\theta,q)\) as:
\[ \begin{aligned} \mathcal{L}(v,\theta,q)=\log p(v;\theta) - D_{KL}(q(z|v)||p(z|v;\theta)) \end{aligned} \tag{8}\]
where \(q\) is an arbitrary probability distribution over \(z\). The difference between the log-probability \(\log p(v;\theta)\) and \(\mathcal{L}(v,\theta,q)\) is given by the KL divergence, which is always non-negative. We can conclude that \(\mathcal{L}\) is at most equal to the desired log-probability. The two are equal if and only if \(q\) is the same distribution as \(p(z|v;\theta)\).
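The decomposition of the evidence into the ELBO plus a KL term can be verified on a model small enough to enumerate. The following is a minimal sketch with a made-up three-state latent variable and a single fixed observation; all probability values are illustrative assumptions.

```python
import numpy as np

# Hypothetical model: 3-state latent z, one fixed observation v.
prior = np.array([0.5, 0.3, 0.2])        # p(z)
lik = np.array([0.10, 0.60, 0.30])       # p(v | z) evaluated at the observed v
joint = prior * lik                      # p(v, z)
evidence = joint.sum()                   # p(v), by summing out z
posterior = joint / evidence             # p(z | v)

def elbo(q):
    """E_{z ~ q}[log p(v, z) - log q(z)] for a discrete q."""
    return float(np.sum(q * (np.log(joint) - np.log(q))))

def kl(q, p):
    """D_KL(q || p) for discrete distributions with full support."""
    return float(np.sum(q * (np.log(q) - np.log(p))))

q = np.array([0.2, 0.5, 0.3])            # an arbitrary distribution over z
gap = float(np.log(evidence)) - elbo(q)
print(gap, kl(q, posterior))             # the two numbers coincide
print(float(np.log(evidence)) - elbo(posterior))  # zero gap when q is the posterior
```

For any valid `q` the gap is non-negative, and it vanishes when `q` is set to the exact posterior.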
We can rearrange \(\mathcal{L}\) into:
\[ \begin{aligned} \mathcal{L}(v,\theta,q)&=\log p(v;\theta) - D_{KL}(q(z|v)||p(z|v;\theta)) \\ &=\log p(v;\theta) - \mathbb{E}_{z\sim q}\log \frac{q(z|v)}{p(z|v;\theta)} \\ &=\log p(v;\theta) - \mathbb{E}_{z\sim q}\log \frac{q(z|v)}{\frac{p(z,v;\theta)}{p(v;\theta)}} \\ &=\log p(v;\theta) - \mathbb{E}_{z\sim q} \left[\log q(z|v)-\log p(z,v;\theta) + \log p(v;\theta) \right] \\ &=-\mathbb{E}_{z\sim q} \left[\log q(z|v) - \log p(z,v;\theta) \right] \end{aligned} \tag{9}\]
Hence a more canonical form of the evidence lower bound is: \[ \begin{aligned} \mathcal{L}(v,\theta,q)=\mathbb{E}_{z\sim q} \left[\log p(z,v;\theta)\right]+ H(q) \end{aligned} \tag{10}\] where \(H(q)=-\mathbb{E}_{z\sim q}\left[\log q(z|v)\right]\) is the entropy of \(q\).
For an appropriate choice of \(q\), \(\mathcal{L}\) is tractable to compute. The better \(q(z|v)\) approximates \(p(z|v)\), the closer the lower bound \(\mathcal{L}\) is to \(\log p(v)\). When \(q(z|v)=p(z|v)\), the approximation is perfect and \(\mathcal{L}=\log p(v;\theta)\). Thus, we can think of variational inference as the procedure of finding the \(q\) that maximizes \(\mathcal{L}\).
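To make the idea of finding \(q\) by maximizing \(\mathcal{L}\) concrete, here is a minimal sketch on a toy model where the exact answer is known. The model (\(z\sim\mathcal{N}(0,1)\), \(v|z\sim\mathcal{N}(z,1)\)), the Gaussian family \(q=\mathcal{N}(m,s^2)\), the observed value, and the step size are all assumptions chosen for illustration. The gradient is estimated from reparameterized samples \(z=m+s\epsilon\), with the derivative of \(\log p(z,v)\) written by hand rather than obtained from an automatic differentiation library.

```python
import numpy as np

# Assumed toy model: z ~ N(0, 1), v | z ~ N(z, 1).
# Then p(z | v) = N(v / 2, 1 / 2) and p(v) = N(0, 2), so the target is known.
v = 1.5
rng = np.random.default_rng(0)

def dlogjoint_dz(z):
    # d/dz [log p(z) + log p(v | z)] = -z + (v - z)
    return v - 2.0 * z

def elbo(m, s):
    # Closed form of E_q[log p(z, v)] + H(q) for q = N(m, s^2).
    expected_log_joint = (-np.log(2 * np.pi)
                          - 0.5 * (m**2 + s**2)
                          - 0.5 * ((v - m)**2 + s**2))
    entropy = 0.5 * np.log(2 * np.pi * np.e * s**2)
    return expected_log_joint + entropy

m, log_s = 0.0, 0.0                      # variational parameters of q
lr, n_samples = 0.05, 256
for step in range(2000):
    s = np.exp(log_s)
    eps = rng.standard_normal(n_samples)
    z = m + s * eps                      # reparameterized samples from q
    g = dlogjoint_dz(z)
    grad_m = g.mean()                    # gradient of E_q[log p(z, v)] w.r.t. m
    # Chain rule through s = exp(log_s), plus d H(q) / d log_s = 1.
    grad_log_s = (g * eps).mean() * s + 1.0
    m += lr * grad_m                     # stochastic gradient ascent on the ELBO
    log_s += lr * grad_log_s

s = np.exp(log_s)
log_evidence = -0.5 * np.log(2 * np.pi * 2.0) - v**2 / 4.0
print(m, s**2, elbo(m, s), log_evidence)
```

Under these assumptions the fitted mean and variance should settle near the exact posterior values \(v/2\) and \(1/2\), up to Monte Carlo noise, and the ELBO approaches \(\log p(v)\) from below.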