Exact Inference


Variable Elimination

Let $p(x_1,\ldots,x_n)$ be a probability distribution. A marginal of a probability distribution $p$ is the function obtained by fixing the value of one (or more) of the random variables and summing over the remaining variables. For example, the marginal function, $p_i(x_i)$, corresponding to the $i^{th}$ variable is given by \begin{eqnarray} p_i(x_i) = \sum_{x' \text{ s.t. } x'_i = x_i} p(x'_1,\ldots,x'_n). \end{eqnarray}

Computing the marginals of a given probability distribution is one type of statistical inference and is, in general, a computationally expensive operation. If $X_1,\ldots,X_n$ take values in the set $\{1,\ldots,k\}$, then the marginal $p_i(x_i)$ is computed by summing over all $k^{n-1}$ different possible assignments to the other variables for each of the $k$ values of $x_i$. However, when the probability distribution has some additional structure, marginal distributions can sometimes be computed with significantly fewer operations.
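The brute-force computation described above can be sketched as follows; the joint distribution here is a hypothetical random table over $n = 3$ variables with values in $\{0,\ldots,k-1\}$, normalized to sum to one.

```python
import itertools
import random

# Hypothetical joint distribution: a random table over n = 3 variables,
# each taking values in {0, ..., k-1}, normalized to sum to 1.
k, n = 4, 3
random.seed(0)
weights = {x: random.random()
           for x in itertools.product(range(k), repeat=n)}
total = sum(weights.values())
p = {x: w / total for x, w in weights.items()}

def marginal(p, i):
    """p_i(x_i): for each of the k values of x_i, sum p over the
    k**(n-1) assignments to the remaining variables."""
    return [sum(prob for x, prob in p.items() if x[i] == v)
            for v in range(k)]

p1 = marginal(p, 0)   # marginal of the first variable
```

Note that `marginal` touches all $k^n$ entries of the table, matching the $k \cdot k^{n-1}$ cost discussed above.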

As an example, suppose $p(x_1,x_2,x_3)$ is a probability distribution over three random variables that take values in the set $\{1,\ldots,k\}$. Further, suppose that $p$ can be written as a product of functions as follows. \begin{eqnarray} p(x_1,x_2,x_3) = q_{12}(x_1,x_2)q_{13}(x_1,x_3) \end{eqnarray} Now, consider computing the marginal of $p$ corresponding to the variable $x_1$. \begin{eqnarray} p_1(x_1) = \sum_{x_2} \sum_{x_3} p(x_1,x_2, x_3) \end{eqnarray} As $x_2$ and $x_3$ can each take one of $k$ different values, this summation contains $k^2$ distinct terms for each fixed value of $x_1$. However, by exploiting the observation that $p$ can be written as a product, we can rewrite the summation as \begin{eqnarray} p_1(x_1) & = & \sum_{x_2} \sum_{x_3} q_{12}(x_1,x_2)q_{13}(x_1,x_3)\\ & = & \sum_{x_2} q_{12}(x_1,x_2) \Big[\sum_{x_3} q_{13}(x_1,x_3)\Big] \end{eqnarray} which requires summing only $2k$ distinct terms for each fixed value of $x_1$. To see this, first compute \[r(x_1) \triangleq \sum_{x_3} q_{13}(x_1,x_3)\] by summing over the $k$ values of $x_3$ for each of the $k$ possible values of $x_1$. Next, compute \[s(x_1) \triangleq \sum_{x_2} q_{12}(x_1,x_2)\] by summing over the $k$ values of $x_2$ for each of the $k$ possible values of $x_1$. The marginal distribution $p_1(x_1)$ is then equal to $r(x_1)\cdot s(x_1)$.
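The savings can be checked numerically. The sketch below uses hypothetical random tables for $q_{12}$ and $q_{13}$; the identity $\sum_{x_2}\sum_{x_3} q_{12} q_{13} = r(x_1)\, s(x_1)$ holds for any nonnegative tables, independent of normalization.

```python
import random

# Hypothetical random tables standing in for q12(x1, x2) and q13(x1, x3).
k = 4
random.seed(1)
q12 = [[random.random() for _ in range(k)] for _ in range(k)]
q13 = [[random.random() for _ in range(k)] for _ in range(k)]

# r(x1) = sum_{x3} q13(x1, x3) and s(x1) = sum_{x2} q12(x1, x2):
# k terms each, so 2k terms per fixed value of x1.
r = [sum(q13[x1][x3] for x3 in range(k)) for x1 in range(k)]
s = [sum(q12[x1][x2] for x2 in range(k)) for x1 in range(k)]
factored = [r[x1] * s[x1] for x1 in range(k)]

# Brute force: k * k = k^2 terms per fixed value of x1.
brute = [sum(q12[x1][x2] * q13[x1][x3]
             for x2 in range(k) for x3 in range(k))
         for x1 in range(k)]

# The two computations agree up to floating-point error.
assert all(abs(a - b) < 1e-9 for a, b in zip(factored, brute))
```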

This procedure of iteratively summing out one variable at a time in order to compute the marginal distributions is known as variable elimination. Specifically, the variable elimination algorithm fixes an ordering of the random variables. The random variables are then summed out, one at a time, following this order. The order in which the variables are eliminated can have a significant impact on the running time of the variable elimination procedure in practice. To build intuition about this, consider the following example.
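The general procedure can be sketched as follows. This is a minimal illustration, not a complete implementation: a factor is a `(scope, table)` pair, and eliminating a variable multiplies together the factors whose scope contains it and sums it out. The example factors below (playing the roles of $q_{12}$ and $q_{13}$) are hypothetical.

```python
import itertools

k = 3  # each variable takes values in {0, ..., k-1}

def make_factor(scope, fn):
    """A factor: a tuple of variable names plus a table mapping each
    assignment of those variables to a nonnegative value."""
    return (scope, {a: fn(*a)
                    for a in itertools.product(range(k), repeat=len(scope))})

def sum_out(var, factors):
    """Eliminate `var`: multiply the factors whose scope contains it,
    sum over its k values, and return the untouched factors together
    with the resulting new factor."""
    touching = [f for f in factors if var in f[0]]
    rest = [f for f in factors if var not in f[0]]
    new_scope = tuple(sorted({v for s, _ in touching for v in s} - {var}))
    table = {}
    for a in itertools.product(range(k), repeat=len(new_scope)):
        ctx = dict(zip(new_scope, a))
        total = 0.0
        for xv in range(k):
            ctx[var] = xv
            prod = 1.0
            for s, t in touching:
                prod *= t[tuple(ctx[v] for v in s)]
            total += prod
        table[a] = total
    return rest + [(new_scope, table)]

# p(x1,x2,x3) proportional to q12(x1,x2) * q13(x1,x3), as in the example
# above; eliminate x3 first, then x2.
factors = [make_factor(("x1", "x2"), lambda a, b: 1.0 + a + 2 * b),
           make_factor(("x1", "x3"), lambda a, b: 1.0 + a * b)]
for var in ("x3", "x2"):
    factors = sum_out(var, factors)

# Every remaining factor has scope ("x1",); their product, normalized,
# is the marginal p_1(x1).
unnorm = [1.0] * k
for scope, table in factors:
    for x1 in range(k):
        unnorm[x1] *= table[(x1,)]
p1 = [u / sum(unnorm) for u in unnorm]
```

The elimination order is fixed here by the tuple `("x3", "x2")`; changing that order changes the scopes of the intermediate factors, which is exactly where the running-time differences discussed above come from.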
