An einsum use-case

Recently I've been working with some of my students on a paper for the SciPy Conference related to our OrNet project. Part of this involved designing our own Gaussian Mixture Model (GMM). In doing so, I somewhat accidentally-on-purpose stumbled upon the beauty of einsum. So I wanted to briefly motivate its use (expanding upon the short use-case I wrote on my research group's blog) and show the performance results of re-implementing iterative computations using einsum.

Gaussian Madness

One of the critical steps of this algorithm is to evaluate the probability of any given point $\vec{x}$ under the model that we've learned so far (the current values of $\vec{\mu}$ and $\Sigma$). In higher dimensions (as in, higher than 1), this has a pretty tricky formula:

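$$ P(\vec{x} | \vec{\mu}, \Sigma) = \frac{\exp{(-\frac{1}{2}(\vec{x} - \vec{\mu})^T \Sigma^{-1} (\vec{x} - \vec{\mu}) )}} {\sqrt{(2 \pi)^d \det{(\Sigma)} }} $$

Here $d$ is the number of dimensions of $\vec{x}$; the data in this post has $d = 2$, which is why the code further down hard-codes the exponent as 2.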

It's...messy. The things that make it particularly tough to efficiently evaluate include:

  • the matrix multiplication in the numerator
  • the inversion of $\Sigma$ in the numerator
  • the determinant of $\Sigma$ in the denominator
  • the fact that this has to be computed for every $\vec{x}$ (thousands), under every combination of $\mu$ and $\Sigma$ (often a few dozen), at every iteration of the algorithm (often a few dozen)
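As a sanity check on the formula, here is a minimal sketch that evaluates it by hand for a single point and compares the result with `scipy.stats.multivariate_normal`. The values of `x`, `mu`, and `sigma` are made up for illustration, and `d` is the dimensionality of the point (2 here).

```python
import numpy as np
from scipy.stats import multivariate_normal

# Illustrative values only; none of these come from the real dataset.
x = np.array([0.3, 0.7])
mu = np.array([0.5, 0.5])
sigma = np.array([[2.0, 0.3],
                  [0.3, 1.0]])

d = x.shape[0]                                 # dimensionality of the data
diff = x - mu
mahal = diff @ np.linalg.inv(sigma) @ diff     # (x - mu)^T Sigma^{-1} (x - mu)
norm = np.sqrt((2 * np.pi) ** d * np.linalg.det(sigma))
p_manual = np.exp(-0.5 * mahal) / norm

# Reference value from SciPy's own multivariate normal density.
p_scipy = multivariate_normal(mean=mu, cov=sigma).pdf(x)
print(np.isclose(p_manual, p_scipy))
```

Doing this once is cheap; the pain comes from repeating the quadratic form for every point, component, and iteration.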

But wait! That's not all!

Also critical to the GMM are the update steps: re-estimating the current values of the parameters so they (hopefully) better fit your data. This happens a lot!

In multidimensional wonderlands, the updates of the parameters look something like this:

$$ \vec{\mu}_{ki} = \frac{1}{N_k} \sum_{n = 1}^N r_{kn} \vec{x}_n $$

$$ \Sigma_{ki} = \frac{1}{N_k} \sum_{n = 1}^N r_{kn} (\vec{x}_n - \vec{\mu}_{ki})(\vec{x}_n - \vec{\mu}_{ki})^T $$

The first equation for $\vec{\mu}_{ki}$--the $k^{th}$ component (one of the 50-100) at the $i^{th}$ iteration (often at least a few dozen)--is a "fairly" easy multiplication of each data point $\vec{x}_n$ by some number $r_{kn}$, all summed up and scaled by $\frac{1}{N_k}$.

The second equation is a lot harder.

For every $\vec{\mu}_{ki}$ we have a corresponding $\Sigma_{ki}$, which, remember, is a full matrix. This involves taking the $\vec{\mu}_{ki}$ we computed in the first equation, subtracting it from every data point $\vec{x}_n$, and computing the outer product of each difference with itself.

An "outer product" is a column vector multiplied by a row vector, so every single one of them produces a full matrix!

Bring on the Code

In [1]:
import numpy as np
import scipy.linalg as sla

So of course, I initially implemented these two core operations in the way that made the most sense: using numpy functions and Python comprehensions.

For computing $P(\vec{x})$:

In [2]:
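def px1(X, mu, sigma):
    det = sla.det(sigma)
    inv = sla.inv(sigma)
    # One quadratic form (x - mu)^T inv (x - mu) per data point.
    p = np.array([(x - mu).dot(inv).dot(x - mu) for x in X])
    # Normalizing constant; the exponent 2 is the dimensionality of the data.
    n = 1 / ((((2 * np.pi) ** 2) * det) ** 0.5)
    px = np.exp(-0.5 * p) * n
    return px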

And for re-computing $\mu$ and $\Sigma$:

In [3]:
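def musig1(X, r):
    N = r.sum()
    # Weighted mean: scale each row of X by its weight, sum, and normalize.
    mu_next = (r[:, np.newaxis] * X).sum(axis = 0) / N

    # Weighted covariance: accumulate one outer product per data point.
    s = np.zeros(shape = (X.shape[1], X.shape[1]))
    for x, ri in zip(X - mu_next, r):
        s += (np.outer(x, x) * ri)
    sigma_next = s / N

    return mu_next, sigma_next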

These were computational bears. In px1, the line starting p = ... is a huge computational cost: doing two dot products inside a loop over all the data points $\vec{x}_n$ is really, really expensive.

And in musig1, the operation to compute the next value of $\Sigma$ involves subtracting the freshly updated $\vec{\mu}$ from every $\vec{x}_n$ and computing the outer product of each difference with itself, one Python-level loop iteration at a time.
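One way to convince yourself the update equations are doing what they claim is to compare them against `np.cov`, which accepts per-observation weights through its `aweights` argument. The sketch below restates the loop from musig1 on a small random dataset (the sizes and the seed are arbitrary stand-ins, not the real data) and compares the results.

```python
import numpy as np

rng = np.random.default_rng(0)      # arbitrary seed
X = rng.random((500, 2))            # small stand-in dataset
r = rng.random(500)                 # stand-in weights r_kn for one component

# Same computation as musig1: weighted mean, then a weighted sum of outer products.
N_k = r.sum()
mu_next = (r[:, np.newaxis] * X).sum(axis=0) / N_k
s = np.zeros((X.shape[1], X.shape[1]))
for x, ri in zip(X - mu_next, r):
    s += np.outer(x, x) * ri
sigma_next = s / N_k

# An outer product is a column vector times a row vector: a full 2 x 2 matrix.
d0 = X[0] - mu_next
outer_ok = np.allclose(np.outer(d0, d0), d0[:, np.newaxis] @ d0[np.newaxis, :])

# np.cov expects variables in rows, hence X.T; ddof=0 divides by the sum of the weights.
mu_ref = np.average(X, axis=0, weights=r)
sigma_ref = np.cov(X.T, aweights=r, ddof=0)

print(outer_ok, np.allclose(mu_next, mu_ref), np.allclose(sigma_next, sigma_ref))
```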

Suffice it to say, it shouldn't come as a surprise that when it came time to actually run the algorithm on a dataset of roughly 130,000 points in 2 dimensions, each iteration took about 52 seconds.

...the vast majority of which was spent computing dot products!

[Figure: cProfile output before the einsum rewrite]

More than half of a single iteration was devoted to dot products in the px1 function above. That's bad.

Just for grins, let's time everything:

In [4]:
X = np.random.random(size = (130000, 2))  # To simulate the data I was working with.
mu = X.mean(axis = 0)  # Compute an actual mean.
sigma = np.array([[5, 0], [0, 5]])
r = np.random.random(130000)
In [5]:
%timeit px1(X, mu, sigma)
301 ms ± 6.99 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [6]:
%timeit musig1(X, r)
847 ms ± 5.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

That's over a full second of computation (301 ms + 847 ms) for just 1 of the Gaussian components (i.e., one pair of $\mu$ and $\Sigma$). Crank that number of components up to 50 and, well, there's your ~50-second runtime.

I was admittedly at a bit of a loss here, considering I was already using NumPy constructs. But at the same time, these were highly structured linear algebra operations; surely there had to be something more efficient.

einsum to the Rescue

DISCLAIMER: I should highlight that I didn't immediately land on this. It took some Googling around and some sitting and thinking before I figured out how this could be used. Then it took some more trial-and-error to actually get it to work.

But work, it did. Holy crap.

np.einsum is NumPy's implementation of Einstein summation, which is a super-fancy term for a cool (but admittedly tricky-to-grasp) shorthand notation for dealing with multiplications and summations of vectors, matrices, and tensors: you give every axis of every input a letter, and any letter left out of the output gets summed over.

I won't rehash the basic examples I've found across the web that helped me understand how this works, but I will give special emphasis to two articles:

  1. Einstein Summation in NumPy, and
  2. A basic introduction to NumPy's einsum

I found #2 particularly accessible, but I'd recommend both.
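As a quick orientation before the rewrites, here is a minimal sketch of the three einsum behaviors that matter for this post. The arrays are made up, and each result is checked against the ordinary NumPy call it replaces.

```python
import numpy as np

A = np.arange(6.0).reshape(2, 3)     # made-up 2x3 matrix
B = np.arange(12.0).reshape(3, 4)    # made-up 3x4 matrix
v = np.array([1.0, 2.0, 3.0])        # made-up vector

# A shared label that is missing from the output is multiplied and summed over:
# 'ij,jk->ik' is ordinary matrix multiplication.
matmul = np.einsum('ij,jk->ik', A, B)

# Dropping a label from the output sums along that axis: row sums of A.
row_sums = np.einsum('ij->i', A)

# Keeping every label multiplies without summing: an outer product.
outer = np.einsum('i,j->ij', v, v)

print(np.allclose(matmul, A @ B))
print(np.allclose(row_sums, A.sum(axis=1)))
print(np.allclose(outer, np.outer(v, v)))
```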

It took some doing, but I managed to rewrite both equations using einsums:

In [7]:
def px2(X, mu, sigma):
    det = sla.det(sigma)
    inv = sla.inv(sigma)

    ###
    # i pairs X - mu with the rows of inv; j pairs the columns of inv with X - mu.
    p = np.einsum('ni,ij,nj->n', X - mu, inv, X - mu)
    ###

    n = 1 / ((((2 * np.pi) ** 2) * det) ** 0.5)
    px = np.exp(-0.5 * p) * n
    return px
In [8]:
def musig2(X, r):
    N = r.sum()
    mu_next = (r[:, np.newaxis] * X).sum(axis = 0) / N

    ###
    sigma_next = np.einsum("ni,nj->ij", (X - mu_next) * r[:, np.newaxis], X - mu_next) / N
    ###

    return mu_next, sigma_next

The musig2 use actually ended up being a straightforward application of the "each row of $A$ multiplied by $B$" rule at the end of the #2 link above, with the slight wrinkle of making the first axis, n, the one that is summed over in the result.

The one in px2 was a lot harder, and for a time I wasn't convinced it could be done. At one point I was able to basically split the operation into two einsums that were more or less identical to the one in musig2; it basically involved a repeated application of "multiply each row of $A$ by $B$". But I kept hunting around the web, particularly for examples of multiplying many matrices together, and eventually managed to figure this one out.

It boiled down to the basics of how einsum works: I knew the output had to be a list of numbers n long, and that the first and third inputs were the same array, X - mu. The key was: if everything had 2 axes going in, but only 1 axis was coming out, then every axis except n had to be summed over. The subscripts also have to chain the way the matrix products do: i links the first input to the rows of the inverse, and j links its columns to the third input. (It's tempting to label the first and third inputs identically, as in 'ni,ji,ni->n', but that weights each squared difference by a column sum of the inverse and only agrees with px1 when $\Sigma$ is diagonal, as it happens to be in the test data here.)
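Because a diagonal $\Sigma$ can hide a mislabeled subscript, a check against the loop version is best done with a covariance that has off-diagonal terms. The sketch below uses made-up data and a made-up, non-diagonal `sigma`. It compares the list comprehension from px1, a single three-operand einsum whose labels chain like the matrix products, and the two-step split mentioned above.

```python
import numpy as np

rng = np.random.default_rng(1)          # arbitrary seed
X = rng.random((1000, 2))               # made-up data
mu = X.mean(axis=0)
sigma = np.array([[2.0, 0.6],
                  [0.6, 1.0]])          # made-up covariance with off-diagonal terms
inv = np.linalg.inv(sigma)
D = X - mu

# Reference: one quadratic form per data point, as in px1.
p_loop = np.array([d.dot(inv).dot(d) for d in D])

# Single einsum: i pairs D with the rows of inv, j pairs inv's columns with D.
p_einsum = np.einsum('ni,ij,nj->n', D, inv, D)

# Two-step version: first D times inv, then a row-wise dot product with D.
tmp = np.einsum('ni,ij->nj', D, inv)
p_two_step = np.einsum('nj,nj->n', tmp, D)

print(np.allclose(p_loop, p_einsum), np.allclose(p_loop, p_two_step))
```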

With great power blah blah blah

Just for fun, let's run our examples again:

In [9]:
%timeit px2(X, mu, sigma)
6.82 ms ± 32.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [10]:
%timeit musig2(X, r)
7.37 ms ± 260 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

To sum up:

  • Going from a list comprehension of calls to numpy.dot to a single einsum decreased the runtime of a single call from about 300ms to about 7ms, a speed-up of over 40x.
  • Going from a loop of numpy.outer products to a single einsum decreased the runtime of a single call from 847ms to about 7.4ms, a speed-up of over 110x.
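A speed-up is only worth having if the answers agree. As a sketch on made-up data, the einsum form of the covariance update can be compared with the loop of outer products. The plain matrix product at the end is an equivalent formulation added here for comparison; it is not part of the original implementation.

```python
import numpy as np

rng = np.random.default_rng(2)          # arbitrary seed
X = rng.random((2000, 2))               # made-up data
r = rng.random(2000)                    # made-up weights for one component

N_k = r.sum()
mu_next = (r[:, np.newaxis] * X).sum(axis=0) / N_k
D = X - mu_next

# Loop of outer products, as in musig1.
s = np.zeros((2, 2))
for d, ri in zip(D, r):
    s += np.outer(d, d) * ri
sigma_loop = s / N_k

# Single einsum, as in musig2: n is summed away, i and j survive.
sigma_einsum = np.einsum('ni,nj->ij', D * r[:, np.newaxis], D) / N_k

# The same sum written as a matrix product, shown only for comparison.
sigma_matmul = (D * r[:, np.newaxis]).T @ D / N_k

print(np.allclose(sigma_loop, sigma_einsum), np.allclose(sigma_loop, sigma_matmul))
```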

I re-ran cProfile to see what my performance gains were, and holy crap Batman.

[Figure: cProfile output after the einsum rewrite]

Not only was my total runtime down from ~52 seconds per iteration to a mere 1.5 seconds per iteration, but the most time-consuming block of my code was now einsum itself!

Of course, I wasn't going to argue with a better-than-order-of-magnitude reduction in runtime (roughly 35x).

Our SciPy paper is still under review, so feel free to head over there and check out how we're putting GMMs to use for bioimaging!
