K means

The k-means clustering algorithm with Python and J implementations.

K Means

The idea of this algorithm is that you have some data which can be "clustered" into groups based on some distance metric.

  1. choose a number for k (how many clusters)
  2. start with a random guess for the centroid of each cluster
  3. using these centroids, find the points which are closest to each centroid
  4. for each group of points, calculate a new centroid
  5. goto 3 until convergence

Example (1d)

  • k: 3
  • initial centroid locations: 5 15 25
  • data points: 1 3 9 11 12 37 43 45 60
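
Worked by hand, the first assignment puts 1 3 9 with centroid 5 (the point 9 is distance 4 from 5 but distance 6 from 15), 11 12 with 15, and 37 43 45 60 with 25. Averaging each group gives new centroids of roughly 4.33 11.5 46.25.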

J solution:

0j1":{{p(+/%#)/.~{.@/:"1|p-"0 _ y}}^:1[5 15 25[p=:1 3 9 11 12 37 43 45 60

This Python solution is an approximate translation of the J solution. Both take the absolute difference between (all) the data points and each individual centroid, use those differences to find the index of the nearest centroid for each point, and then take the mean of all the points closest to each centroid. That mean becomes the new centroid for that group of points.

import numpy as np

p = np.array([1, 3, 9, 11, 12, 37, 43, 45, 60])
c = np.array([5, 15, 25])

def foo(c, p):
    # index of the nearest centroid for each point
    m = np.argmin(np.abs(p - c[:, None]), axis=0)
    # mean of the points assigned to each centroid
    x = [np.mean([p[j] for j in range(len(p)) if m[j] == i]) for i in range(len(c))]
    return np.array(x)

print(np.round(foo(c, p), 1))  # one update: [ 4.3 11.5 46.2]

Discussion

Obviously the J version is shorter, but that doesn't mean it's better. However, to me (a person who knows what the J symbols mean), the J version more clearly expresses the algorithm.

Here's what I mean by that.

In J, function arguments have default names y (if there's one argument) or x and y (if there are two). Within the function I wrote, this part executes first:

|p-"0 _ y

This is equivalent to the following Python:

np.abs(p-c[:,None])

Merely replacing the symbol for abs (|) with the word abs doesn't tell the whole story. What's really interesting is that the - function is modified to apply at a non-default rank. In this case, its left rank is 0 and its right rank is infinite (spelled _, and meaning "equal to the highest argument rank"). For vectors x and y, x -"(0 _) y implies subtracting the entire vector y from each element of x.

This illustrates a key difference between J and Python's NumPy library: in J, all functions have a rank (which the user may override), whereas in NumPy, arrays are broadcast against each other according to their shapes. This is why I used c[:,None] to add a leading axis to c, transforming it into a 1-column matrix rather than a vector.

Essentially, in NumPy, you modify the data to fit the broadcasting rules. But in J, you modify the function to fit the data.
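
To make the broadcasting concrete, here's a small NumPy sketch of the intermediate array, using the example data from above:

import numpy as np

p = np.array([1, 3, 9, 11, 12, 37, 43, 45, 60])
c = np.array([5, 15, 25])

# c[:, None] has shape (3, 1); broadcasting it against p (shape (9,))
# yields a (3, 9) matrix of absolute distances, one row per centroid
d = np.abs(p - c[:, None])
print(d.shape)  # (3, 9)
print(d[0])     # distances from each point to the first centroid: [ 4  2  4  6  7 32 38 40 55]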

Next up in the J explanation is this snippet (and from here I'll use ... to indicate "the previously explained expression"):

{.@/:"1 ...

This is equivalent to np.argsort(..., axis=0)[0], but I used np.argmin(..., axis=0) in the Python version since it is more conventional.

In both J and Python, this "argmin" operation returns a vector that looks essentially like this:

0 0 0 1 1 2 2 2 2
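
Continuing the NumPy sketch from above, both spellings agree on the example data:

d = np.abs(p - c[:, None])
print(np.argsort(d, axis=0)[0])  # [0 0 0 1 1 2 2 2 2]
print(np.argmin(d, axis=0))      # same result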

Last we have this section:

p(+/%#)/.~ ...

This uses key (/.) which is like "group by" in SQL, but with a twist: rather than solely forming groups, it applies a function to the groups. In this case, the function is arithmetic mean (+/%#). I think of this like "find groups in … and get the mean of each group".

This is roughly equivalent to the following Python:

[np.mean([p[j] for j in range(len(p)) if m[j] == i]) for i in range(len(c))]

Since learning about /., I find that it expresses this step of the algorithm much more clearly. Maybe this is an unfair comparison, because the above Python uses two nested list comprehensions, which is probably considered bad style.

Here's a translation using regular for loops:

ls = []
for i in range(len(c)):
    ls.append([])
    for j in range(len(p)):
        if m[j] == i:
            ls[i].append(p[j])

    ls[i] = np.mean(ls[i])

This version has very little complexity per line. Maybe that makes it easier to understand. However, it takes up much more space, and the majority of the characters are devoted to moving through the structures; only a tiny fraction deal with the core of the algorithm. To me, this makes the operations (and the data, for that matter) feel visually less important than the control flow.

Currently, I prefer p(+/%#)/.~.
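
As an aside, NumPy can get closer to the spirit of /. without explicit Python loops. Here's one alternative, a sketch using np.bincount (my own substitution, not part of the translation above), reusing m and p from earlier:

# sum the points in each group, then divide by the group sizes;
# note this would divide by zero if a centroid ended up with no points
sums = np.bincount(m, weights=p, minlength=len(c))
counts = np.bincount(m, minlength=len(c))
print(sums / counts)  # [ 4.33333333 11.5 46.25 ] (approximately)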

One More Thing

So far, we've only dealt with a single iteration of the k-means algorithm. But usually we want to run multiple iterations, updating the centroids each time, to hopefully converge on a good solution.

Here's a Python version which iterates until the result repeats itself:

import numpy as np

p = np.array([1, 3, 9, 11, 12, 37, 43, 45, 60])
c = np.array([5, 15, 25])

def foo(c, p):
    m = np.argmin(np.abs(p - c[:, None]), axis=0)
    x = [np.mean([p[j] for j in range(len(p)) if m[j] == i]) for i in range(len(c))]
    return np.array(x)

# iterate until the centroids stop changing
while True:
    cnew = foo(c, p)
    if np.allclose(cnew, c):
        break
    c = cnew

print(np.round(cnew, 1))  # converged centroids: [ 2.  10.7 46.2]

I added a while loop, another numpy function (np.allclose), and another variable.
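
Another way to look at it: the while loop is a hand-rolled fixed-point iteration. Factoring that out into a reusable helper gets Python a little closer to the spirit of ^:_ (a sketch; fixpoint is my own name, not a standard library function):

def fixpoint(f, x):
    # apply f repeatedly until the output stops changing
    while True:
        fx = f(x)
        if np.allclose(fx, x):
            return fx
        x = fx

print(np.round(fixpoint(lambda c: foo(c, p), c), 1))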

And finally, here's the J version that iterates until convergence:

0j1":{{p(+/%#)/.~{.@/:"1|p-"0 _ y}}^:_[5 15 25[p=:1 3 9 11 12 37 43 45 60

That's a 1-character diff, so I'll highlight it here:

- ^:1
+ ^:_

Instead of "do the function once", this means "do the function until it stops changing". This showcases what J is all about: finding essential abstractions that compose well, and writing them concisely.