Implementing k-means in Scala

To generate sample data, I selected two points, (10, 20) and (25, 5), then generated a list of normally distributed points around each of them; the exact points used are in the code below.
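The post does not show how the points were generated. One way to produce similar data is to add Gaussian noise from `scala.util.Random.nextGaussian` to each center. This is a sketch under stated assumptions: the `SampleData` object, the `Pt` case class, the spread `sigma = 2.0` and the fixed seed are all hypothetical choices. Only the two centers and the 19 points per center match the data in the code below.

```scala
import scala.util.Random

// Hypothetical helper: draws normally distributed points around fixed centers.
// sigma and the seed are assumptions, not the values used for this post.
object SampleData {
  final case class Pt(x: Double, y: Double)

  def around(centers: Seq[Pt], perCenter: Int, sigma: Double, seed: Long): List[Pt] = {
    val rng = new Random(seed) // fixed seed so the data set is reproducible
    centers.toList.flatMap { c =>
      // List.fill re-evaluates its argument, so every point gets fresh noise
      List.fill(perCenter)(Pt(c.x + sigma * rng.nextGaussian(), c.y + sigma * rng.nextGaussian()))
    }
  }

  def main(args: Array[String]): Unit = {
    val data = around(Seq(Pt(10, 20), Pt(25, 5)), perCenter = 19, sigma = 2.0, seed = 42L)
    // Print in the same shape as the list in the k-means source
    data.foreach(p => println(s"new Point(${p.x}, ${p.y}),"))
  }
}
```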

This implements Lloyd’s algorithm, which clusters points by repeating a few simple steps:

1. Choose the number of clusters, k
2. Split the points into k groups at random
3. Compute the center (mean) of each group
4. For each point, find the closest center
5. Regroup all the points by their closest center
6. Repeat steps 3 to 5 until the groups stop changing, or for a fixed number of rounds
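Steps 2 through 5 map naturally onto small pure functions. This is a minimal sketch, not the code from this post: it uses a `Pt` case class, a seeded `scala.util.Random.shuffle` for the random grouping (the post instead sorts the points by a hash of their coordinates), and the names `LloydSketch`, `initial` and `step` are hypothetical.

```scala
import scala.util.Random

// Minimal sketch of Lloyd's algorithm on immutable lists
object LloydSketch {
  final case class Pt(x: Double, y: Double)

  // Squared Euclidean distance: enough to decide which center is closest
  def dist2(a: Pt, b: Pt): Double =
    (a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y)

  // The center of a non-empty group is the mean of its points
  def mean(ps: List[Pt]): Pt =
    Pt(ps.map(_.x).sum / ps.size, ps.map(_.y).sum / ps.size)

  // Step 2: shuffle with a fixed seed, then deal the points round-robin into k groups
  def initial(points: List[Pt], k: Int, seed: Long): List[List[Pt]] =
    new Random(seed).shuffle(points)
      .zipWithIndex
      .groupBy { case (_, i) => i % k }
      .values
      .map(_.map { case (p, _) => p })
      .toList

  // Steps 3 to 5: compute every center first, then build brand-new groups by
  // sending each point to the index of its closest center. A center that
  // attracts no points simply does not appear in the result.
  def step(groups: List[List[Pt]]): List[List[Pt]] = {
    val centers = groups.map(mean)
    groups.flatten
      .groupBy(p => centers.indices.minBy(i => dist2(p, centers(i))))
      .values
      .toList
  }
}
```

Calling `step` on its own output carries out step 6.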

I like how the functional programming style forces you to recreate all the data structures here. It might be tempting to write this imperatively, modifying data structures in place, but step 4 compares each point against centers computed from the old groups while step 5 builds the new ones; building fresh structures each round rules out accidentally mixing the two. You can see the full source below, or on GitHub.
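Step 6 fits the same immutable style: rather than updating a grouping in a loop, keep feeding each new value back into the step function until it stops changing. This is a hedged, generic sketch. The `Iterate.untilStable` name, its iteration cap and its stopping rule (stop at the first step whose output equals its input) are assumptions of this example. The post's own loop, judging by the log, simply runs a fixed number of iterations.

```scala
import scala.annotation.tailrec

// Hypothetical helper: applies `step` to an immutable value until the output
// equals the input (a fixed point) or `maxIterations` steps have run.
// Returns the final value and the number of times `step` was called.
object Iterate {
  def untilStable[A](start: A, maxIterations: Int)(step: A => A): (A, Int) = {
    @tailrec
    def loop(current: A, done: Int): (A, Int) =
      if (done >= maxIterations) (current, done)
      else {
        val next = step(current) // a brand-new value; `current` is untouched
        if (next == current) (current, done + 1)
        else loop(next, done + 1)
      }
    loop(start, 0)
  }
}
```

For k-means, `A` would be the grouping itself. Representing it as a `Set[Set[...]]` of points with structural equality (a case class, for example) makes `==` ignore how the clusters happen to be numbered from one round to the next.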

Because the example is fairly contrived, it converges quickly. From iteration 1 onward the groups no longer change:

Initial State: 
  Cluster 0
  Mean: (17.83517750970944, 12.242720407317105)
    (10.8348626966492, 18.7800980127523))
    (7.7875624720831, 20.1569764307574))
    (11.9096128931784, 21.1855674228972))
    (22.4668345067162, 8.9705504626857))
    (7.91362116378194, 21.325928219919))
    (22.636600400773, 2.46561420928429))
    (13.0838514816799, 20.3398794353494))
    (11.7396623802245, 17.7026240456956))
    (25.1439536911272, 3.58469981317611))
    (23.5359486724204, 4.07290025106778))
    (11.7493214262468, 17.8517235677469))
    (12.4277617893575, 19.4887691804508))
    (11.931275122466, 18.0462702532436))
    (25.4645673159779, 7.54703465191098))
    (21.8031183153743, 5.69297814349064))
    (23.9177161897547, 8.1377950229489))
    (24.5349708443852, 5.00561881333415))
    (26.2100410238973, 5.06220487544192))
    (23.7770902983858, 7.19445492687232))

  Cluster 1
  Mean: (16.95249500233747, 12.848199048608048)
    (11.7265904596619, 16.9636039793709))
    (10.7751248849735, 22.1517666115673))
    (23.6587920739353, 3.35476798095758))
    (21.4930923464916, 3.28999356823389))
    (26.4748241341303, 9.25128245838802))
    (7.03171204763376, 19.1985058633283))
    (23.7722765903534, 3.74873642284525))
    (10.259545802461, 23.4515683763173))
    (28.1587146197594, 3.70625885635717))
    (10.1057940183815, 18.7332929859685))
    (8.90149362263775, 19.6314465074203))
    (12.4353462881232, 19.6310467981989))
    (24.3793349065557, 4.59761596097384))
    (22.5447925324242, 2.99485404382734))
    (26.8942422516129, 5.02646862012427))
    (6.56491029696013, 21.5098251711267))
    (8.87507602702847, 21.4823134390704))
    (27.0339042858296, 4.4151109960116))
    (11.0118378554584, 20.9773232834654))

Iteration: 0
  Cluster 0
  Mean: (23.781370272978315, 5.754127202865132)
    (11.7265904596619, 16.9636039793709))
    (23.6587920739353, 3.35476798095758))
    (22.4668345067162, 8.9705504626857))
    (21.4930923464916, 3.28999356823389))
    (26.4748241341303, 9.25128245838802))
    (22.636600400773, 2.46561420928429))
    (23.7722765903534, 3.74873642284525))
    (25.1439536911272, 3.58469981317611))
    (28.1587146197594, 3.70625885635717))
    (23.5359486724204, 4.07290025106778))
    (24.3793349065557, 4.59761596097384))
    (25.4645673159779, 7.54703465191098))
    (22.5447925324242, 2.99485404382734))
    (21.8031183153743, 5.69297814349064))
    (26.8942422516129, 5.02646862012427))
    (23.9177161897547, 8.1377950229489))
    (24.5349708443852, 5.00561881333415))
    (26.2100410238973, 5.06220487544192))
    (27.0339042858296, 4.4151109960116))
    (23.7770902983858, 7.19445492687232))

  Cluster 1
  Mean: (10.296576237184727, 20.09138475584863)
    (10.8348626966492, 18.7800980127523))
    (7.7875624720831, 20.1569764307574))
    (10.7751248849735, 22.1517666115673))
    (11.9096128931784, 21.1855674228972))
    (7.91362116378194, 21.325928219919))
    (7.03171204763376, 19.1985058633283))
    (13.0838514816799, 20.3398794353494))
    (11.7396623802245, 17.7026240456956))
    (10.259545802461, 23.4515683763173))
    (10.1057940183815, 18.7332929859685))
    (11.7493214262468, 17.8517235677469))
    (8.90149362263775, 19.6314465074203))
    (12.4277617893575, 19.4887691804508))
    (12.4353462881232, 19.6310467981989))
    (11.931275122466, 18.0462702532436))
    (6.56491029696013, 21.5098251711267))
    (8.87507602702847, 21.4823134390704))
    (11.0118378554584, 20.9773232834654))

Iteration: 1
  Cluster 0
  Mean: (24.415832368416023, 5.164154740943777)
    (23.6587920739353, 3.35476798095758))
    (22.4668345067162, 8.9705504626857))
    (21.4930923464916, 3.28999356823389))
    (26.4748241341303, 9.25128245838802))
    (22.636600400773, 2.46561420928429))
    (23.7722765903534, 3.74873642284525))
    (25.1439536911272, 3.58469981317611))
    (28.1587146197594, 3.70625885635717))
    (23.5359486724204, 4.07290025106778))
    (24.3793349065557, 4.59761596097384))
    (25.4645673159779, 7.54703465191098))
    (22.5447925324242, 2.99485404382734))
    (21.8031183153743, 5.69297814349064))
    (26.8942422516129, 5.02646862012427))
    (23.9177161897547, 8.1377950229489))
    (24.5349708443852, 5.00561881333415))
    (26.2100410238973, 5.06220487544192))
    (27.0339042858296, 4.4151109960116))
    (23.7770902983858, 7.19445492687232))

  Cluster 1
  Mean: (10.371840143630894, 19.92676471498138)
    (10.8348626966492, 18.7800980127523))
    (11.7265904596619, 16.9636039793709))
    (7.7875624720831, 20.1569764307574))
    (10.7751248849735, 22.1517666115673))
    (11.9096128931784, 21.1855674228972))
    (7.91362116378194, 21.325928219919))
    (7.03171204763376, 19.1985058633283))
    (13.0838514816799, 20.3398794353494))
    (11.7396623802245, 17.7026240456956))
    (10.259545802461, 23.4515683763173))
    (10.1057940183815, 18.7332929859685))
    (11.7493214262468, 17.8517235677469))
    (8.90149362263775, 19.6314465074203))
    (12.4277617893575, 19.4887691804508))
    (12.4353462881232, 19.6310467981989))
    (11.931275122466, 18.0462702532436))
    (6.56491029696013, 21.5098251711267))
    (8.87507602702847, 21.4823134390704))
    (11.0118378554584, 20.9773232834654))

Iteration: 2
  Cluster 0
  Mean: (24.415832368416023, 5.164154740943777)
    (23.6587920739353, 3.35476798095758))
    (22.4668345067162, 8.9705504626857))
    (21.4930923464916, 3.28999356823389))
    (26.4748241341303, 9.25128245838802))
    (22.636600400773, 2.46561420928429))
    (23.7722765903534, 3.74873642284525))
    (25.1439536911272, 3.58469981317611))
    (28.1587146197594, 3.70625885635717))
    (23.5359486724204, 4.07290025106778))
    (24.3793349065557, 4.59761596097384))
    (25.4645673159779, 7.54703465191098))
    (22.5447925324242, 2.99485404382734))
    (21.8031183153743, 5.69297814349064))
    (26.8942422516129, 5.02646862012427))
    (23.9177161897547, 8.1377950229489))
    (24.5349708443852, 5.00561881333415))
    (26.2100410238973, 5.06220487544192))
    (27.0339042858296, 4.4151109960116))
    (23.7770902983858, 7.19445492687232))

  Cluster 1
  Mean: (10.371840143630894, 19.92676471498138)
    (10.8348626966492, 18.7800980127523))
    (11.7265904596619, 16.9636039793709))
    (7.7875624720831, 20.1569764307574))
    (10.7751248849735, 22.1517666115673))
    (11.9096128931784, 21.1855674228972))
    (7.91362116378194, 21.325928219919))
    (7.03171204763376, 19.1985058633283))
    (13.0838514816799, 20.3398794353494))
    (11.7396623802245, 17.7026240456956))
    (10.259545802461, 23.4515683763173))
    (10.1057940183815, 18.7332929859685))
    (11.7493214262468, 17.8517235677469))
    (8.90149362263775, 19.6314465074203))
    (12.4277617893575, 19.4887691804508))
    (12.4353462881232, 19.6310467981989))
    (11.931275122466, 18.0462702532436))
    (6.56491029696013, 21.5098251711267))
    (8.87507602702847, 21.4823134390704))
    (11.0118378554584, 20.9773232834654))

Iteration: 3
  Cluster 0
  Mean: (24.415832368416023, 5.164154740943777)
    (23.6587920739353, 3.35476798095758))
    (22.4668345067162, 8.9705504626857))
    (21.4930923464916, 3.28999356823389))
    (26.4748241341303, 9.25128245838802))
    (22.636600400773, 2.46561420928429))
    (23.7722765903534, 3.74873642284525))
    (25.1439536911272, 3.58469981317611))
    (28.1587146197594, 3.70625885635717))
    (23.5359486724204, 4.07290025106778))
    (24.3793349065557, 4.59761596097384))
    (25.4645673159779, 7.54703465191098))
    (22.5447925324242, 2.99485404382734))
    (21.8031183153743, 5.69297814349064))
    (26.8942422516129, 5.02646862012427))
    (23.9177161897547, 8.1377950229489))
    (24.5349708443852, 5.00561881333415))
    (26.2100410238973, 5.06220487544192))
    (27.0339042858296, 4.4151109960116))
    (23.7770902983858, 7.19445492687232))

  Cluster 1
  Mean: (10.371840143630894, 19.92676471498138)
    (10.8348626966492, 18.7800980127523))
    (11.7265904596619, 16.9636039793709))
    (7.7875624720831, 20.1569764307574))
    (10.7751248849735, 22.1517666115673))
    (11.9096128931784, 21.1855674228972))
    (7.91362116378194, 21.325928219919))
    (7.03171204763376, 19.1985058633283))
    (13.0838514816799, 20.3398794353494))
    (11.7396623802245, 17.7026240456956))
    (10.259545802461, 23.4515683763173))
    (10.1057940183815, 18.7332929859685))
    (11.7493214262468, 17.8517235677469))
    (8.90149362263775, 19.6314465074203))
    (12.4277617893575, 19.4887691804508))
    (12.4353462881232, 19.6310467981989))
    (11.931275122466, 18.0462702532436))
    (6.56491029696013, 21.5098251711267))
    (8.87507602702847, 21.4823134390704))
    (11.0118378554584, 20.9773232834654))

Iteration: 4
  Cluster 0
  Mean: (24.415832368416023, 5.164154740943777)
    (23.6587920739353, 3.35476798095758))
    (22.4668345067162, 8.9705504626857))
    (21.4930923464916, 3.28999356823389))
    (26.4748241341303, 9.25128245838802))
    (22.636600400773, 2.46561420928429))
    (23.7722765903534, 3.74873642284525))
    (25.1439536911272, 3.58469981317611))
    (28.1587146197594, 3.70625885635717))
    (23.5359486724204, 4.07290025106778))
    (24.3793349065557, 4.59761596097384))
    (25.4645673159779, 7.54703465191098))
    (22.5447925324242, 2.99485404382734))
    (21.8031183153743, 5.69297814349064))
    (26.8942422516129, 5.02646862012427))
    (23.9177161897547, 8.1377950229489))
    (24.5349708443852, 5.00561881333415))
    (26.2100410238973, 5.06220487544192))
    (27.0339042858296, 4.4151109960116))
    (23.7770902983858, 7.19445492687232))

  Cluster 1
  Mean: (10.371840143630894, 19.92676471498138)
    (10.8348626966492, 18.7800980127523))
    (11.7265904596619, 16.9636039793709))
    (7.7875624720831, 20.1569764307574))
    (10.7751248849735, 22.1517666115673))
    (11.9096128931784, 21.1855674228972))
    (7.91362116378194, 21.325928219919))
    (7.03171204763376, 19.1985058633283))
    (13.0838514816799, 20.3398794353494))
    (11.7396623802245, 17.7026240456956))
    (10.259545802461, 23.4515683763173))
    (10.1057940183815, 18.7332929859685))
    (11.7493214262468, 17.8517235677469))
    (8.90149362263775, 19.6314465074203))
    (12.4277617893575, 19.4887691804508))
    (12.4353462881232, 19.6310467981989))
    (11.931275122466, 18.0462702532436))
    (6.56491029696013, 21.5098251711267))
    (8.87507602702847, 21.4823134390704))
    (11.0118378554584, 20.9773232834654))

Iteration: 5
  Cluster 0
  Mean: (24.415832368416023, 5.164154740943777)
    (23.6587920739353, 3.35476798095758))
    (22.4668345067162, 8.9705504626857))
    (21.4930923464916, 3.28999356823389))
    (26.4748241341303, 9.25128245838802))
    (22.636600400773, 2.46561420928429))
    (23.7722765903534, 3.74873642284525))
    (25.1439536911272, 3.58469981317611))
    (28.1587146197594, 3.70625885635717))
    (23.5359486724204, 4.07290025106778))
    (24.3793349065557, 4.59761596097384))
    (25.4645673159779, 7.54703465191098))
    (22.5447925324242, 2.99485404382734))
    (21.8031183153743, 5.69297814349064))
    (26.8942422516129, 5.02646862012427))
    (23.9177161897547, 8.1377950229489))
    (24.5349708443852, 5.00561881333415))
    (26.2100410238973, 5.06220487544192))
    (27.0339042858296, 4.4151109960116))
    (23.7770902983858, 7.19445492687232))

  Cluster 1
  Mean: (10.371840143630894, 19.92676471498138)
    (10.8348626966492, 18.7800980127523))
    (11.7265904596619, 16.9636039793709))
    (7.7875624720831, 20.1569764307574))
    (10.7751248849735, 22.1517666115673))
    (11.9096128931784, 21.1855674228972))
    (7.91362116378194, 21.325928219919))
    (7.03171204763376, 19.1985058633283))
    (13.0838514816799, 20.3398794353494))
    (11.7396623802245, 17.7026240456956))
    (10.259545802461, 23.4515683763173))
    (10.1057940183815, 18.7332929859685))
    (11.7493214262468, 17.8517235677469))
    (8.90149362263775, 19.6314465074203))
    (12.4277617893575, 19.4887691804508))
    (12.4353462881232, 19.6310467981989))
    (11.931275122466, 18.0462702532436))
    (6.56491029696013, 21.5098251711267))
    (8.87507602702847, 21.4823134390704))
    (11.0118378554584, 20.9773232834654))

class Point(dx: Double, dy: Double) {
  val x: Double = dx
  val y: Double = dy
 
  override def toString(): String = {
    "(" + x + ", " + y + ")"
  }
 
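  // Squared Euclidean distance. The square root is left out because it does
  // not change which mean is closest, and that is all the algorithm needs.
  def dist(p: Point): Double = {
    (x - p.x) * (x - p.x) + (y - p.y) * (y - p.y)
  }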
}
 
object kmeans extends App {
  val k: Int = 2
 
  // The true cluster centers are (10, 20) and (25, 5)
  val points: List[Point] = List(
    new Point(10.8348626966492, 18.7800980127523),
    new Point(10.259545802461, 23.4515683763173),
    new Point(11.7396623802245, 17.7026240456956),
    new Point(12.4277617893575, 19.4887691804508),
    new Point(10.1057940183815, 18.7332929859685),
    new Point(11.0118378554584, 20.9773232834654),
    new Point(7.03171204763376, 19.1985058633283),
    new Point(6.56491029696013, 21.5098251711267),
    new Point(10.7751248849735, 22.1517666115673),
    new Point(8.90149362263775, 19.6314465074203),
    new Point(11.931275122466, 18.0462702532436),
    new Point(11.7265904596619, 16.9636039793709),
    new Point(11.7493214262468, 17.8517235677469),
    new Point(12.4353462881232, 19.6310467981989),
    new Point(13.0838514816799, 20.3398794353494),
    new Point(7.7875624720831, 20.1569764307574),
    new Point(11.9096128931784, 21.1855674228972),
    new Point(8.87507602702847, 21.4823134390704),
    new Point(7.91362116378194, 21.325928219919),
    new Point(26.4748241341303, 9.25128245838802),
    new Point(26.2100410238973, 5.06220487544192),
    new Point(28.1587146197594, 3.70625885635717),
    new Point(26.8942422516129, 5.02646862012427),
    new Point(23.7770902983858, 7.19445492687232),
    new Point(23.6587920739353, 3.35476798095758),
    new Point(23.7722765903534, 3.74873642284525),
    new Point(23.9177161897547, 8.1377950229489),
    new Point(22.4668345067162, 8.9705504626857),
    new Point(24.5349708443852, 5.00561881333415),
    new Point(24.3793349065557, 4.59761596097384),
    new Point(27.0339042858296, 4.4151109960116),
    new Point(21.8031183153743, 5.69297814349064),
    new Point(22.636600400773, 2.46561420928429),
    new Point(25.1439536911272, 3.58469981317611),
    new Point(21.4930923464916, 3.28999356823389),
    new Point(23.5359486724204, 4.07290025106778),
    new Point(22.5447925324242, 2.99485404382734),
    new Point(25.4645673159779, 7.54703465191098)).sortBy(
      // sorting by a hash of the coordinates acts as a cheap, repeatable shuffle
      p => s"${p.x} ${p.y}".hashCode())
 
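  // Mean of a non-empty list of points
  def clusterMean(points: List[Point]): Point = {
    val cumulative = points.reduceLeft((a: Point, b: Point) => new Point(a.x + b.x, a.y + b.y))

    new Point(cumulative.x / points.length, cumulative.y / points.length)
  }

  // Print each cluster's number and mean, followed by its member points
  def render(points: Map[Int, List[Point]]): Unit = {
    for (clusterNumber <- points.keys.toList.sorted) {
      println("  Cluster " + clusterNumber)
      println("  Mean: " + clusterMean(points(clusterNumber)))
      for (p <- points(clusterNumber)) {
        println("    " + p)
      }
      println()
    }
  }

  // Initial grouping (step 2): deal the hash-sorted points round-robin into k clusters
  val clusters: Map[Int, List[Point]] =
    points.zipWithIndex
      .groupBy { case (_, index) => index % k }
      .transform((_, indexed) => indexed.map(_._1))

  // One round of Lloyd's algorithm (steps 3 to 5)
  def iterate(clusters: Map[Int, List[Point]]): Map[Int, List[Point]] = {
    // find cluster means, in cluster-number order
    val means: List[Point] =
      clusters.keys.toList.sorted.map(clusterIndex => clusterMean(clusters(clusterIndex)))

    // index of the mean closest to p
    def closest(p: Point, means: List[Point]): Int =
      means.zipWithIndex.minBy { case (m, _) => p.dist(m) }._2

    // regroup every point under its closest mean
    val newClusters = points.groupBy(p => closest(p, means))

    render(newClusters)

    newClusters
  }

  println("Initial State: ")
  render(clusters)

  // Step 6: a fixed number of rounds, matching the log above
  var clusterToTest = clusters
  for (i <- 0 to 5) {
    println("Iteration: " + i)
    clusterToTest = iterate(clusterToTest)
  }
}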
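The `dist` method returns the squared Euclidean distance. Because the square root is an increasing function, ranking centers by squared distance selects the same closest center as ranking by true Euclidean distance (up to floating-point near-ties), and it skips a `math.sqrt` call per comparison. A small standalone check, with hypothetical names (`DistanceCheck`, `squared`, `euclidean`, `closest`) and tuples in place of the post's `Point` class:

```scala
// Hypothetical check that squared and true Euclidean distance pick the same closest center
object DistanceCheck {
  type P = (Double, Double)

  // What Point.dist computes: the squared Euclidean distance
  def squared(a: P, b: P): Double = {
    val dx = a._1 - b._1
    val dy = a._2 - b._2
    dx * dx + dy * dy
  }

  // The true Euclidean distance
  def euclidean(a: P, b: P): Double = math.sqrt(squared(a, b))

  // Index of the closest center under the given distance function
  def closest(p: P, centers: Seq[P], d: (P, P) => Double): Int =
    centers.indices.minBy(i => d(p, centers(i)))

  def main(args: Array[String]): Unit = {
    val centers = Seq((10.0, 20.0), (25.0, 5.0))
    val p = (11.7265904596619, 16.9636039793709) // a point from the data set
    println(closest(p, centers, squared))   // 0: nearer to (10, 20)
    println(closest(p, centers, euclidean)) // 0 as well
  }
}
```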



#1 network_graph on 05.04.13 at 5:01 pm

Interesting entry. Just some remarks:

You can write `case class Point(x: Double, y: Double)` so you don’t have to write public getters for the coordinates; then you can leave out the `new` keyword, e.g. `Point(10.835, 18.780)`.

There are a few more ‘javaish’ things, e.g. you should omit the `return` keyword, and you can write `println` instead of `System.out.println`.

#2 Gary on 05.05.13 at 2:36 am

Great, thanks. I’m still finding my way around with Scala a bit.

#3 Andrew McNaughton on 01.05.14 at 3:00 am

In class “Point”, your “dist” function is not Euclidean distance. Is that intended?

What you have there amounts to the square of the cosine of the vectors from the origin to each point, which seems rather odd.

#4 Gary on 02.01.14 at 3:13 pm

Wow, you’re absolutely right, I updated the post to correct this.

