Implementing k-means in Scala

To generate sample data, I selected two points, (10, 20) and (25, 5), then generated a list of normally distributed points around those two – the exact points used are in the code below.

This implements Lloyd’s algorithm, which clusters points by repeatedly refining an initial grouping:

1. Assume a certain number of clusters
2. Group the points at random
3. Compute the center of each cluster
4. For each point, compute which cluster is closest
5. Move all the points into new groupings
6. Repeat steps 3–5 until the groupings stop changing (the algorithm has converged), or until you’re happy with the results

I like how the functional programming style forces you to recreate all the data structures on each iteration. It might be tempting to implement this in an imperative style, modifying data structures in place. However, steps 4–5 need the old groupings to stay intact while the new ones are being built. Building fresh, immutable structures each time guarantees that, and rules out a whole class of subtle bugs. You can see the full source below, or on GitHub.

Because this example is fairly contrived, it converges quickly: the groupings (and therefore the means) stop changing after iteration 1.

Initial State: 
  Cluster 0
  Mean: (17.83517750970944, 12.242720407317105)
    (10.8348626966492, 18.7800980127523))
    (7.7875624720831, 20.1569764307574))
    (11.9096128931784, 21.1855674228972))
    (22.4668345067162, 8.9705504626857))
    (7.91362116378194, 21.325928219919))
    (22.636600400773, 2.46561420928429))
    (13.0838514816799, 20.3398794353494))
    (11.7396623802245, 17.7026240456956))
    (25.1439536911272, 3.58469981317611))
    (23.5359486724204, 4.07290025106778))
    (11.7493214262468, 17.8517235677469))
    (12.4277617893575, 19.4887691804508))
    (11.931275122466, 18.0462702532436))
    (25.4645673159779, 7.54703465191098))
    (21.8031183153743, 5.69297814349064))
    (23.9177161897547, 8.1377950229489))
    (24.5349708443852, 5.00561881333415))
    (26.2100410238973, 5.06220487544192))
    (23.7770902983858, 7.19445492687232))

  Cluster 1
  Mean: (16.95249500233747, 12.848199048608048)
    (11.7265904596619, 16.9636039793709))
    (10.7751248849735, 22.1517666115673))
    (23.6587920739353, 3.35476798095758))
    (21.4930923464916, 3.28999356823389))
    (26.4748241341303, 9.25128245838802))
    (7.03171204763376, 19.1985058633283))
    (23.7722765903534, 3.74873642284525))
    (10.259545802461, 23.4515683763173))
    (28.1587146197594, 3.70625885635717))
    (10.1057940183815, 18.7332929859685))
    (8.90149362263775, 19.6314465074203))
    (12.4353462881232, 19.6310467981989))
    (24.3793349065557, 4.59761596097384))
    (22.5447925324242, 2.99485404382734))
    (26.8942422516129, 5.02646862012427))
    (6.56491029696013, 21.5098251711267))
    (8.87507602702847, 21.4823134390704))
    (27.0339042858296, 4.4151109960116))
    (11.0118378554584, 20.9773232834654))

Iteration: 0
  Cluster 0
  Mean: (23.781370272978315, 5.754127202865132)
    (11.7265904596619, 16.9636039793709))
    (23.6587920739353, 3.35476798095758))
    (22.4668345067162, 8.9705504626857))
    (21.4930923464916, 3.28999356823389))
    (26.4748241341303, 9.25128245838802))
    (22.636600400773, 2.46561420928429))
    (23.7722765903534, 3.74873642284525))
    (25.1439536911272, 3.58469981317611))
    (28.1587146197594, 3.70625885635717))
    (23.5359486724204, 4.07290025106778))
    (24.3793349065557, 4.59761596097384))
    (25.4645673159779, 7.54703465191098))
    (22.5447925324242, 2.99485404382734))
    (21.8031183153743, 5.69297814349064))
    (26.8942422516129, 5.02646862012427))
    (23.9177161897547, 8.1377950229489))
    (24.5349708443852, 5.00561881333415))
    (26.2100410238973, 5.06220487544192))
    (27.0339042858296, 4.4151109960116))
    (23.7770902983858, 7.19445492687232))

  Cluster 1
  Mean: (10.296576237184727, 20.09138475584863)
    (10.8348626966492, 18.7800980127523))
    (7.7875624720831, 20.1569764307574))
    (10.7751248849735, 22.1517666115673))
    (11.9096128931784, 21.1855674228972))
    (7.91362116378194, 21.325928219919))
    (7.03171204763376, 19.1985058633283))
    (13.0838514816799, 20.3398794353494))
    (11.7396623802245, 17.7026240456956))
    (10.259545802461, 23.4515683763173))
    (10.1057940183815, 18.7332929859685))
    (11.7493214262468, 17.8517235677469))
    (8.90149362263775, 19.6314465074203))
    (12.4277617893575, 19.4887691804508))
    (12.4353462881232, 19.6310467981989))
    (11.931275122466, 18.0462702532436))
    (6.56491029696013, 21.5098251711267))
    (8.87507602702847, 21.4823134390704))
    (11.0118378554584, 20.9773232834654))

Iteration: 1
  Cluster 0
  Mean: (24.415832368416023, 5.164154740943777)
    (23.6587920739353, 3.35476798095758))
    (22.4668345067162, 8.9705504626857))
    (21.4930923464916, 3.28999356823389))
    (26.4748241341303, 9.25128245838802))
    (22.636600400773, 2.46561420928429))
    (23.7722765903534, 3.74873642284525))
    (25.1439536911272, 3.58469981317611))
    (28.1587146197594, 3.70625885635717))
    (23.5359486724204, 4.07290025106778))
    (24.3793349065557, 4.59761596097384))
    (25.4645673159779, 7.54703465191098))
    (22.5447925324242, 2.99485404382734))
    (21.8031183153743, 5.69297814349064))
    (26.8942422516129, 5.02646862012427))
    (23.9177161897547, 8.1377950229489))
    (24.5349708443852, 5.00561881333415))
    (26.2100410238973, 5.06220487544192))
    (27.0339042858296, 4.4151109960116))
    (23.7770902983858, 7.19445492687232))

  Cluster 1
  Mean: (10.371840143630894, 19.92676471498138)
    (10.8348626966492, 18.7800980127523))
    (11.7265904596619, 16.9636039793709))
    (7.7875624720831, 20.1569764307574))
    (10.7751248849735, 22.1517666115673))
    (11.9096128931784, 21.1855674228972))
    (7.91362116378194, 21.325928219919))
    (7.03171204763376, 19.1985058633283))
    (13.0838514816799, 20.3398794353494))
    (11.7396623802245, 17.7026240456956))
    (10.259545802461, 23.4515683763173))
    (10.1057940183815, 18.7332929859685))
    (11.7493214262468, 17.8517235677469))
    (8.90149362263775, 19.6314465074203))
    (12.4277617893575, 19.4887691804508))
    (12.4353462881232, 19.6310467981989))
    (11.931275122466, 18.0462702532436))
    (6.56491029696013, 21.5098251711267))
    (8.87507602702847, 21.4823134390704))
    (11.0118378554584, 20.9773232834654))

Iteration: 2
  Cluster 0
  Mean: (24.415832368416023, 5.164154740943777)
    (23.6587920739353, 3.35476798095758))
    (22.4668345067162, 8.9705504626857))
    (21.4930923464916, 3.28999356823389))
    (26.4748241341303, 9.25128245838802))
    (22.636600400773, 2.46561420928429))
    (23.7722765903534, 3.74873642284525))
    (25.1439536911272, 3.58469981317611))
    (28.1587146197594, 3.70625885635717))
    (23.5359486724204, 4.07290025106778))
    (24.3793349065557, 4.59761596097384))
    (25.4645673159779, 7.54703465191098))
    (22.5447925324242, 2.99485404382734))
    (21.8031183153743, 5.69297814349064))
    (26.8942422516129, 5.02646862012427))
    (23.9177161897547, 8.1377950229489))
    (24.5349708443852, 5.00561881333415))
    (26.2100410238973, 5.06220487544192))
    (27.0339042858296, 4.4151109960116))
    (23.7770902983858, 7.19445492687232))

  Cluster 1
  Mean: (10.371840143630894, 19.92676471498138)
    (10.8348626966492, 18.7800980127523))
    (11.7265904596619, 16.9636039793709))
    (7.7875624720831, 20.1569764307574))
    (10.7751248849735, 22.1517666115673))
    (11.9096128931784, 21.1855674228972))
    (7.91362116378194, 21.325928219919))
    (7.03171204763376, 19.1985058633283))
    (13.0838514816799, 20.3398794353494))
    (11.7396623802245, 17.7026240456956))
    (10.259545802461, 23.4515683763173))
    (10.1057940183815, 18.7332929859685))
    (11.7493214262468, 17.8517235677469))
    (8.90149362263775, 19.6314465074203))
    (12.4277617893575, 19.4887691804508))
    (12.4353462881232, 19.6310467981989))
    (11.931275122466, 18.0462702532436))
    (6.56491029696013, 21.5098251711267))
    (8.87507602702847, 21.4823134390704))
    (11.0118378554584, 20.9773232834654))

Iteration: 3
  Cluster 0
  Mean: (24.415832368416023, 5.164154740943777)
    (23.6587920739353, 3.35476798095758))
    (22.4668345067162, 8.9705504626857))
    (21.4930923464916, 3.28999356823389))
    (26.4748241341303, 9.25128245838802))
    (22.636600400773, 2.46561420928429))
    (23.7722765903534, 3.74873642284525))
    (25.1439536911272, 3.58469981317611))
    (28.1587146197594, 3.70625885635717))
    (23.5359486724204, 4.07290025106778))
    (24.3793349065557, 4.59761596097384))
    (25.4645673159779, 7.54703465191098))
    (22.5447925324242, 2.99485404382734))
    (21.8031183153743, 5.69297814349064))
    (26.8942422516129, 5.02646862012427))
    (23.9177161897547, 8.1377950229489))
    (24.5349708443852, 5.00561881333415))
    (26.2100410238973, 5.06220487544192))
    (27.0339042858296, 4.4151109960116))
    (23.7770902983858, 7.19445492687232))

  Cluster 1
  Mean: (10.371840143630894, 19.92676471498138)
    (10.8348626966492, 18.7800980127523))
    (11.7265904596619, 16.9636039793709))
    (7.7875624720831, 20.1569764307574))
    (10.7751248849735, 22.1517666115673))
    (11.9096128931784, 21.1855674228972))
    (7.91362116378194, 21.325928219919))
    (7.03171204763376, 19.1985058633283))
    (13.0838514816799, 20.3398794353494))
    (11.7396623802245, 17.7026240456956))
    (10.259545802461, 23.4515683763173))
    (10.1057940183815, 18.7332929859685))
    (11.7493214262468, 17.8517235677469))
    (8.90149362263775, 19.6314465074203))
    (12.4277617893575, 19.4887691804508))
    (12.4353462881232, 19.6310467981989))
    (11.931275122466, 18.0462702532436))
    (6.56491029696013, 21.5098251711267))
    (8.87507602702847, 21.4823134390704))
    (11.0118378554584, 20.9773232834654))

Iteration: 4
  Cluster 0
  Mean: (24.415832368416023, 5.164154740943777)
    (23.6587920739353, 3.35476798095758))
    (22.4668345067162, 8.9705504626857))
    (21.4930923464916, 3.28999356823389))
    (26.4748241341303, 9.25128245838802))
    (22.636600400773, 2.46561420928429))
    (23.7722765903534, 3.74873642284525))
    (25.1439536911272, 3.58469981317611))
    (28.1587146197594, 3.70625885635717))
    (23.5359486724204, 4.07290025106778))
    (24.3793349065557, 4.59761596097384))
    (25.4645673159779, 7.54703465191098))
    (22.5447925324242, 2.99485404382734))
    (21.8031183153743, 5.69297814349064))
    (26.8942422516129, 5.02646862012427))
    (23.9177161897547, 8.1377950229489))
    (24.5349708443852, 5.00561881333415))
    (26.2100410238973, 5.06220487544192))
    (27.0339042858296, 4.4151109960116))
    (23.7770902983858, 7.19445492687232))

  Cluster 1
  Mean: (10.371840143630894, 19.92676471498138)
    (10.8348626966492, 18.7800980127523))
    (11.7265904596619, 16.9636039793709))
    (7.7875624720831, 20.1569764307574))
    (10.7751248849735, 22.1517666115673))
    (11.9096128931784, 21.1855674228972))
    (7.91362116378194, 21.325928219919))
    (7.03171204763376, 19.1985058633283))
    (13.0838514816799, 20.3398794353494))
    (11.7396623802245, 17.7026240456956))
    (10.259545802461, 23.4515683763173))
    (10.1057940183815, 18.7332929859685))
    (11.7493214262468, 17.8517235677469))
    (8.90149362263775, 19.6314465074203))
    (12.4277617893575, 19.4887691804508))
    (12.4353462881232, 19.6310467981989))
    (11.931275122466, 18.0462702532436))
    (6.56491029696013, 21.5098251711267))
    (8.87507602702847, 21.4823134390704))
    (11.0118378554584, 20.9773232834654))

Iteration: 5
  Cluster 0
  Mean: (24.415832368416023, 5.164154740943777)
    (23.6587920739353, 3.35476798095758))
    (22.4668345067162, 8.9705504626857))
    (21.4930923464916, 3.28999356823389))
    (26.4748241341303, 9.25128245838802))
    (22.636600400773, 2.46561420928429))
    (23.7722765903534, 3.74873642284525))
    (25.1439536911272, 3.58469981317611))
    (28.1587146197594, 3.70625885635717))
    (23.5359486724204, 4.07290025106778))
    (24.3793349065557, 4.59761596097384))
    (25.4645673159779, 7.54703465191098))
    (22.5447925324242, 2.99485404382734))
    (21.8031183153743, 5.69297814349064))
    (26.8942422516129, 5.02646862012427))
    (23.9177161897547, 8.1377950229489))
    (24.5349708443852, 5.00561881333415))
    (26.2100410238973, 5.06220487544192))
    (27.0339042858296, 4.4151109960116))
    (23.7770902983858, 7.19445492687232))

  Cluster 1
  Mean: (10.371840143630894, 19.92676471498138)
    (10.8348626966492, 18.7800980127523))
    (11.7265904596619, 16.9636039793709))
    (7.7875624720831, 20.1569764307574))
    (10.7751248849735, 22.1517666115673))
    (11.9096128931784, 21.1855674228972))
    (7.91362116378194, 21.325928219919))
    (7.03171204763376, 19.1985058633283))
    (13.0838514816799, 20.3398794353494))
    (11.7396623802245, 17.7026240456956))
    (10.259545802461, 23.4515683763173))
    (10.1057940183815, 18.7332929859685))
    (11.7493214262468, 17.8517235677469))
    (8.90149362263775, 19.6314465074203))
    (12.4277617893575, 19.4887691804508))
    (12.4353462881232, 19.6310467981989))
    (11.931275122466, 18.0462702532436))
    (6.56491029696013, 21.5098251711267))
    (8.87507602702847, 21.4823134390704))
    (11.0118378554584, 20.9773232834654))
/** An immutable 2-D point with `Double` coordinates.
  *
  * @param dx the x coordinate, exposed as [[x]]
  * @param dy the y coordinate, exposed as [[y]]
  */
class Point(dx: Double, dy: Double) {
  val x: Double = dx
  val y: Double = dy

  /** Renders the point as `(x, y)`, e.g. `(1.0, 2.0)`. */
  override def toString(): String = s"($x, $y)"

  /** Returns the ''squared'' Euclidean distance between this point and `p`.
    *
    * The square root is deliberately omitted: squaring preserves the ordering
    * of distances, which is sufficient when distances are only compared
    * (as in a nearest-centre search), and it avoids a `math.sqrt` per call.
    *
    * @param p the other point
    * @return `(x - p.x)^2 + (y - p.y)^2`
    */
  def dist(p: Point): Double = {
    val ddx = x - p.x
    val ddy = y - p.y
    ddx * ddx + ddy * ddy
  }
}

object kmeans extends App {
  /** Number of clusters to partition the sample data into. */
  val k: Int = 2

  /** Sample data: normally distributed points around two true centres,
    * (10, 20) and (25, 5), which the clustering should recover.
    *
    * The points are reordered by the hash code of their "x y" text. This mixes
    * the two groups in a way that is a reproducible stand-in for a random
    * shuffle (String hash codes are fully specified, so every run agrees).
    */
  val points: List[Point] = {
    val coords: List[(Double, Double)] = List(
      (10.8348626966492, 18.7800980127523),
      (10.259545802461, 23.4515683763173),
      (11.7396623802245, 17.7026240456956),
      (12.4277617893575, 19.4887691804508),
      (10.1057940183815, 18.7332929859685),
      (11.0118378554584, 20.9773232834654),
      (7.03171204763376, 19.1985058633283),
      (6.56491029696013, 21.5098251711267),
      (10.7751248849735, 22.1517666115673),
      (8.90149362263775, 19.6314465074203),
      (11.931275122466, 18.0462702532436),
      (11.7265904596619, 16.9636039793709),
      (11.7493214262468, 17.8517235677469),
      (12.4353462881232, 19.6310467981989),
      (13.0838514816799, 20.3398794353494),
      (7.7875624720831, 20.1569764307574),
      (11.9096128931784, 21.1855674228972),
      (8.87507602702847, 21.4823134390704),
      (7.91362116378194, 21.325928219919),
      (26.4748241341303, 9.25128245838802),
      (26.2100410238973, 5.06220487544192),
      (28.1587146197594, 3.70625885635717),
      (26.8942422516129, 5.02646862012427),
      (23.7770902983858, 7.19445492687232),
      (23.6587920739353, 3.35476798095758),
      (23.7722765903534, 3.74873642284525),
      (23.9177161897547, 8.1377950229489),
      (22.4668345067162, 8.9705504626857),
      (24.5349708443852, 5.00561881333415),
      (24.3793349065557, 4.59761596097384),
      (27.0339042858296, 4.4151109960116),
      (21.8031183153743, 5.69297814349064),
      (22.636600400773, 2.46561420928429),
      (25.1439536911272, 3.58469981317611),
      (21.4930923464916, 3.28999356823389),
      (23.5359486724204, 4.07290025106778),
      (22.5447925324242, 2.99485404382734),
      (25.4645673159779, 7.54703465191098)
    )
    val unordered = coords.map { case (px, py) => new Point(px, py) }
    // Sort key: hash of the "x y" string (same text as p.x + " " + p.y).
    unordered.sortBy(p => s"${p.x} ${p.y}".hashCode)
  }

  /** Computes the centroid (arithmetic mean) of one cluster's members.
    *
    * @param points the members of a single cluster; must be non-empty
    * @return a new [[Point]] whose coordinates are the means of the members'
    *         x and y coordinates
    * @throws IllegalArgumentException if `points` is empty (for example a
    *         cluster that lost all of its members during an iteration)
    */
  def clusterMean(points: List[Point]): Point = {
    // Fail fast with a clear message rather than an opaque "empty.reduceLeft".
    require(points.nonEmpty, "clusterMean: cannot take the mean of an empty cluster")
    // Sum both coordinates in a single left-to-right pass.
    val (sumX, sumY) = points.foldLeft((0.0, 0.0)) {
      case ((sx, sy), p) => (sx + p.x, sy + p.y)
    }
    // List.length is O(n); compute it once instead of per coordinate.
    val count = points.length
    new Point(sumX / count, sumY / count)
  }

  // NOTE(review): the code from here to the closing brace below appears to
  // have been mangled when it was pasted/extracted as HTML. The `<-` generator
  // arrows and the text between them are missing (e.g.
  // `for (clusterNumber  x._2 % k)`), and several definitions (rendering,
  // initial cluster assignment, `closest`, the iteration step) have run
  // together. It does not compile as shown; restore it from the original
  // source rather than editing it here.
  def render(points: Map[Int, List[Point]]) {
    for (clusterNumber  x._2 % k) transform (
        (i: Int, p: List[(Point, Int)]) => for (x  clusters.map(cluster => cluster._1)

    // find cluster means
    val means =
      (clusters: Map[Int, List[Point]]) =>
        for (clusterIndex  closest(p, means(clusters)))

    render(newClusters)

    return newClusters
  }

  var clusterToTest = clusters
  for (i