Implementing k-means in Scala

To generate sample data, I selected two points, (10, 20) and (25, 5), then generated a list of normally distributed points around those two – the exact points used are in the code below.
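
For illustration, here is a minimal sketch of how points like these could be generated. It is not the script that produced the data in this post: the names `SampleData`, `around`, and `generate`, the standard deviation of 2.0, and the seed are all assumptions. The 19 points per center match the listing further down.

```scala
import scala.util.Random

object SampleData {
  // Hypothetical helper: n points normally distributed around (cx, cy).
  // stdDev = 2.0 is an assumed spread, not a value taken from the post.
  def around(cx: Double, cy: Double, n: Int, rng: Random,
             stdDev: Double = 2.0): List[(Double, Double)] =
    List.fill(n)((cx + rng.nextGaussian() * stdDev, cy + rng.nextGaussian() * stdDev))

  // A fixed seed (assumed) makes repeated runs produce the same points.
  def generate(seed: Long = 42L): List[(Double, Double)] = {
    val rng = new Random(seed)
    around(10, 20, 19, rng) ++ around(25, 5, 19, rng)
  }
}
```

Calling `SampleData.generate()` returns 38 points, the first 19 scattered around (10, 20) and the rest around (25, 5).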

The code implements Lloyd’s algorithm, a simple iterative procedure for clustering points:

1. Choose the number of clusters, k
2. Group the points at random
3. Compute the center (mean) of each cluster
4. For each point, find the closest cluster center
5. Move each point into the group of its closest center
6. Repeat steps 3-5 until the groupings stop changing, or until you’re happy with the results
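
The steps above can be sketched compactly as follows. This is a simplified illustration rather than the full source in this post: it uses plain `(Double, Double)` tuples, a round-robin initial grouping in place of a random one, and the hypothetical names `LloydSketch`, `initial`, and `step`.

```scala
object LloydSketch {
  type Pt = (Double, Double)

  // Step 3: center (mean) of a non-empty cluster
  def mean(ps: List[Pt]): Pt =
    (ps.map(_._1).sum / ps.length, ps.map(_._2).sum / ps.length)

  // Squared distance is enough for deciding which center is nearer
  def dist2(a: Pt, b: Pt): Double = {
    val dx = a._1 - b._1
    val dy = a._2 - b._2
    dx * dx + dy * dy
  }

  // Step 2 (simplified): deal the points into k groups by index
  def initial(points: List[Pt], k: Int): Map[Int, List[Pt]] =
    points.zipWithIndex.groupBy(_._2 % k).map { case (c, ps) => c -> ps.map(_._1) }

  // Steps 3-5: compute the centers, then regroup every point under its
  // nearest center. A cluster that attracts no points drops out of the map.
  def step(clusters: Map[Int, List[Pt]]): Map[Int, List[Pt]] = {
    val centers = clusters.map { case (c, ps) => c -> mean(ps) }
    clusters.values.flatten.toList
      .groupBy(p => centers.minBy { case (_, m) => dist2(p, m) }._1)
  }
}
```

Each call to `step` returns a new map and leaves its input untouched, so step 6 amounts to calling it repeatedly, starting from `LloydSketch.initial(points, 2)`.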

I like how the functional programming style forces you to build new data structures at each step, in this case. It might be tempting to implement this in an imperative style, modifying data structures in place, but steps 4-5 need the old groupings and the new ones kept separate, so the functional style protects you from making things harder for yourself. You can see the full source below, or on GitHub.

Since this example is fairly contrived, the algorithm converges quickly, and the clusters stop changing after iteration 1:

Initial State: 
  Cluster 0
  Mean: (17.83517750970944, 12.242720407317105)
    (10.8348626966492, 18.7800980127523))
    (7.7875624720831, 20.1569764307574))
    (11.9096128931784, 21.1855674228972))
    (22.4668345067162, 8.9705504626857))
    (7.91362116378194, 21.325928219919))
    (22.636600400773, 2.46561420928429))
    (13.0838514816799, 20.3398794353494))
    (11.7396623802245, 17.7026240456956))
    (25.1439536911272, 3.58469981317611))
    (23.5359486724204, 4.07290025106778))
    (11.7493214262468, 17.8517235677469))
    (12.4277617893575, 19.4887691804508))
    (11.931275122466, 18.0462702532436))
    (25.4645673159779, 7.54703465191098))
    (21.8031183153743, 5.69297814349064))
    (23.9177161897547, 8.1377950229489))
    (24.5349708443852, 5.00561881333415))
    (26.2100410238973, 5.06220487544192))
    (23.7770902983858, 7.19445492687232))

  Cluster 1
  Mean: (16.95249500233747, 12.848199048608048)
    (11.7265904596619, 16.9636039793709))
    (10.7751248849735, 22.1517666115673))
    (23.6587920739353, 3.35476798095758))
    (21.4930923464916, 3.28999356823389))
    (26.4748241341303, 9.25128245838802))
    (7.03171204763376, 19.1985058633283))
    (23.7722765903534, 3.74873642284525))
    (10.259545802461, 23.4515683763173))
    (28.1587146197594, 3.70625885635717))
    (10.1057940183815, 18.7332929859685))
    (8.90149362263775, 19.6314465074203))
    (12.4353462881232, 19.6310467981989))
    (24.3793349065557, 4.59761596097384))
    (22.5447925324242, 2.99485404382734))
    (26.8942422516129, 5.02646862012427))
    (6.56491029696013, 21.5098251711267))
    (8.87507602702847, 21.4823134390704))
    (27.0339042858296, 4.4151109960116))
    (11.0118378554584, 20.9773232834654))

Iteration: 0
  Cluster 0
  Mean: (23.781370272978315, 5.754127202865132)
    (11.7265904596619, 16.9636039793709))
    (23.6587920739353, 3.35476798095758))
    (22.4668345067162, 8.9705504626857))
    (21.4930923464916, 3.28999356823389))
    (26.4748241341303, 9.25128245838802))
    (22.636600400773, 2.46561420928429))
    (23.7722765903534, 3.74873642284525))
    (25.1439536911272, 3.58469981317611))
    (28.1587146197594, 3.70625885635717))
    (23.5359486724204, 4.07290025106778))
    (24.3793349065557, 4.59761596097384))
    (25.4645673159779, 7.54703465191098))
    (22.5447925324242, 2.99485404382734))
    (21.8031183153743, 5.69297814349064))
    (26.8942422516129, 5.02646862012427))
    (23.9177161897547, 8.1377950229489))
    (24.5349708443852, 5.00561881333415))
    (26.2100410238973, 5.06220487544192))
    (27.0339042858296, 4.4151109960116))
    (23.7770902983858, 7.19445492687232))

  Cluster 1
  Mean: (10.296576237184727, 20.09138475584863)
    (10.8348626966492, 18.7800980127523))
    (7.7875624720831, 20.1569764307574))
    (10.7751248849735, 22.1517666115673))
    (11.9096128931784, 21.1855674228972))
    (7.91362116378194, 21.325928219919))
    (7.03171204763376, 19.1985058633283))
    (13.0838514816799, 20.3398794353494))
    (11.7396623802245, 17.7026240456956))
    (10.259545802461, 23.4515683763173))
    (10.1057940183815, 18.7332929859685))
    (11.7493214262468, 17.8517235677469))
    (8.90149362263775, 19.6314465074203))
    (12.4277617893575, 19.4887691804508))
    (12.4353462881232, 19.6310467981989))
    (11.931275122466, 18.0462702532436))
    (6.56491029696013, 21.5098251711267))
    (8.87507602702847, 21.4823134390704))
    (11.0118378554584, 20.9773232834654))

Iteration: 1
  Cluster 0
  Mean: (24.415832368416023, 5.164154740943777)
    (23.6587920739353, 3.35476798095758))
    (22.4668345067162, 8.9705504626857))
    (21.4930923464916, 3.28999356823389))
    (26.4748241341303, 9.25128245838802))
    (22.636600400773, 2.46561420928429))
    (23.7722765903534, 3.74873642284525))
    (25.1439536911272, 3.58469981317611))
    (28.1587146197594, 3.70625885635717))
    (23.5359486724204, 4.07290025106778))
    (24.3793349065557, 4.59761596097384))
    (25.4645673159779, 7.54703465191098))
    (22.5447925324242, 2.99485404382734))
    (21.8031183153743, 5.69297814349064))
    (26.8942422516129, 5.02646862012427))
    (23.9177161897547, 8.1377950229489))
    (24.5349708443852, 5.00561881333415))
    (26.2100410238973, 5.06220487544192))
    (27.0339042858296, 4.4151109960116))
    (23.7770902983858, 7.19445492687232))

  Cluster 1
  Mean: (10.371840143630894, 19.92676471498138)
    (10.8348626966492, 18.7800980127523))
    (11.7265904596619, 16.9636039793709))
    (7.7875624720831, 20.1569764307574))
    (10.7751248849735, 22.1517666115673))
    (11.9096128931784, 21.1855674228972))
    (7.91362116378194, 21.325928219919))
    (7.03171204763376, 19.1985058633283))
    (13.0838514816799, 20.3398794353494))
    (11.7396623802245, 17.7026240456956))
    (10.259545802461, 23.4515683763173))
    (10.1057940183815, 18.7332929859685))
    (11.7493214262468, 17.8517235677469))
    (8.90149362263775, 19.6314465074203))
    (12.4277617893575, 19.4887691804508))
    (12.4353462881232, 19.6310467981989))
    (11.931275122466, 18.0462702532436))
    (6.56491029696013, 21.5098251711267))
    (8.87507602702847, 21.4823134390704))
    (11.0118378554584, 20.9773232834654))

Iteration: 2
  Cluster 0
  Mean: (24.415832368416023, 5.164154740943777)
    (23.6587920739353, 3.35476798095758))
    (22.4668345067162, 8.9705504626857))
    (21.4930923464916, 3.28999356823389))
    (26.4748241341303, 9.25128245838802))
    (22.636600400773, 2.46561420928429))
    (23.7722765903534, 3.74873642284525))
    (25.1439536911272, 3.58469981317611))
    (28.1587146197594, 3.70625885635717))
    (23.5359486724204, 4.07290025106778))
    (24.3793349065557, 4.59761596097384))
    (25.4645673159779, 7.54703465191098))
    (22.5447925324242, 2.99485404382734))
    (21.8031183153743, 5.69297814349064))
    (26.8942422516129, 5.02646862012427))
    (23.9177161897547, 8.1377950229489))
    (24.5349708443852, 5.00561881333415))
    (26.2100410238973, 5.06220487544192))
    (27.0339042858296, 4.4151109960116))
    (23.7770902983858, 7.19445492687232))

  Cluster 1
  Mean: (10.371840143630894, 19.92676471498138)
    (10.8348626966492, 18.7800980127523))
    (11.7265904596619, 16.9636039793709))
    (7.7875624720831, 20.1569764307574))
    (10.7751248849735, 22.1517666115673))
    (11.9096128931784, 21.1855674228972))
    (7.91362116378194, 21.325928219919))
    (7.03171204763376, 19.1985058633283))
    (13.0838514816799, 20.3398794353494))
    (11.7396623802245, 17.7026240456956))
    (10.259545802461, 23.4515683763173))
    (10.1057940183815, 18.7332929859685))
    (11.7493214262468, 17.8517235677469))
    (8.90149362263775, 19.6314465074203))
    (12.4277617893575, 19.4887691804508))
    (12.4353462881232, 19.6310467981989))
    (11.931275122466, 18.0462702532436))
    (6.56491029696013, 21.5098251711267))
    (8.87507602702847, 21.4823134390704))
    (11.0118378554584, 20.9773232834654))

Iteration: 3
  Cluster 0
  Mean: (24.415832368416023, 5.164154740943777)
    (23.6587920739353, 3.35476798095758))
    (22.4668345067162, 8.9705504626857))
    (21.4930923464916, 3.28999356823389))
    (26.4748241341303, 9.25128245838802))
    (22.636600400773, 2.46561420928429))
    (23.7722765903534, 3.74873642284525))
    (25.1439536911272, 3.58469981317611))
    (28.1587146197594, 3.70625885635717))
    (23.5359486724204, 4.07290025106778))
    (24.3793349065557, 4.59761596097384))
    (25.4645673159779, 7.54703465191098))
    (22.5447925324242, 2.99485404382734))
    (21.8031183153743, 5.69297814349064))
    (26.8942422516129, 5.02646862012427))
    (23.9177161897547, 8.1377950229489))
    (24.5349708443852, 5.00561881333415))
    (26.2100410238973, 5.06220487544192))
    (27.0339042858296, 4.4151109960116))
    (23.7770902983858, 7.19445492687232))

  Cluster 1
  Mean: (10.371840143630894, 19.92676471498138)
    (10.8348626966492, 18.7800980127523))
    (11.7265904596619, 16.9636039793709))
    (7.7875624720831, 20.1569764307574))
    (10.7751248849735, 22.1517666115673))
    (11.9096128931784, 21.1855674228972))
    (7.91362116378194, 21.325928219919))
    (7.03171204763376, 19.1985058633283))
    (13.0838514816799, 20.3398794353494))
    (11.7396623802245, 17.7026240456956))
    (10.259545802461, 23.4515683763173))
    (10.1057940183815, 18.7332929859685))
    (11.7493214262468, 17.8517235677469))
    (8.90149362263775, 19.6314465074203))
    (12.4277617893575, 19.4887691804508))
    (12.4353462881232, 19.6310467981989))
    (11.931275122466, 18.0462702532436))
    (6.56491029696013, 21.5098251711267))
    (8.87507602702847, 21.4823134390704))
    (11.0118378554584, 20.9773232834654))

Iteration: 4
  Cluster 0
  Mean: (24.415832368416023, 5.164154740943777)
    (23.6587920739353, 3.35476798095758))
    (22.4668345067162, 8.9705504626857))
    (21.4930923464916, 3.28999356823389))
    (26.4748241341303, 9.25128245838802))
    (22.636600400773, 2.46561420928429))
    (23.7722765903534, 3.74873642284525))
    (25.1439536911272, 3.58469981317611))
    (28.1587146197594, 3.70625885635717))
    (23.5359486724204, 4.07290025106778))
    (24.3793349065557, 4.59761596097384))
    (25.4645673159779, 7.54703465191098))
    (22.5447925324242, 2.99485404382734))
    (21.8031183153743, 5.69297814349064))
    (26.8942422516129, 5.02646862012427))
    (23.9177161897547, 8.1377950229489))
    (24.5349708443852, 5.00561881333415))
    (26.2100410238973, 5.06220487544192))
    (27.0339042858296, 4.4151109960116))
    (23.7770902983858, 7.19445492687232))

  Cluster 1
  Mean: (10.371840143630894, 19.92676471498138)
    (10.8348626966492, 18.7800980127523))
    (11.7265904596619, 16.9636039793709))
    (7.7875624720831, 20.1569764307574))
    (10.7751248849735, 22.1517666115673))
    (11.9096128931784, 21.1855674228972))
    (7.91362116378194, 21.325928219919))
    (7.03171204763376, 19.1985058633283))
    (13.0838514816799, 20.3398794353494))
    (11.7396623802245, 17.7026240456956))
    (10.259545802461, 23.4515683763173))
    (10.1057940183815, 18.7332929859685))
    (11.7493214262468, 17.8517235677469))
    (8.90149362263775, 19.6314465074203))
    (12.4277617893575, 19.4887691804508))
    (12.4353462881232, 19.6310467981989))
    (11.931275122466, 18.0462702532436))
    (6.56491029696013, 21.5098251711267))
    (8.87507602702847, 21.4823134390704))
    (11.0118378554584, 20.9773232834654))

Iteration: 5
  Cluster 0
  Mean: (24.415832368416023, 5.164154740943777)
    (23.6587920739353, 3.35476798095758))
    (22.4668345067162, 8.9705504626857))
    (21.4930923464916, 3.28999356823389))
    (26.4748241341303, 9.25128245838802))
    (22.636600400773, 2.46561420928429))
    (23.7722765903534, 3.74873642284525))
    (25.1439536911272, 3.58469981317611))
    (28.1587146197594, 3.70625885635717))
    (23.5359486724204, 4.07290025106778))
    (24.3793349065557, 4.59761596097384))
    (25.4645673159779, 7.54703465191098))
    (22.5447925324242, 2.99485404382734))
    (21.8031183153743, 5.69297814349064))
    (26.8942422516129, 5.02646862012427))
    (23.9177161897547, 8.1377950229489))
    (24.5349708443852, 5.00561881333415))
    (26.2100410238973, 5.06220487544192))
    (27.0339042858296, 4.4151109960116))
    (23.7770902983858, 7.19445492687232))

  Cluster 1
  Mean: (10.371840143630894, 19.92676471498138)
    (10.8348626966492, 18.7800980127523))
    (11.7265904596619, 16.9636039793709))
    (7.7875624720831, 20.1569764307574))
    (10.7751248849735, 22.1517666115673))
    (11.9096128931784, 21.1855674228972))
    (7.91362116378194, 21.325928219919))
    (7.03171204763376, 19.1985058633283))
    (13.0838514816799, 20.3398794353494))
    (11.7396623802245, 17.7026240456956))
    (10.259545802461, 23.4515683763173))
    (10.1057940183815, 18.7332929859685))
    (11.7493214262468, 17.8517235677469))
    (8.90149362263775, 19.6314465074203))
    (12.4277617893575, 19.4887691804508))
    (12.4353462881232, 19.6310467981989))
    (11.931275122466, 18.0462702532436))
    (6.56491029696013, 21.5098251711267))
    (8.87507602702847, 21.4823134390704))
    (11.0118378554584, 20.9773232834654))

class Point(dx: Double, dy: Double) {
  val x: Double = dx
  val y: Double = dy

  override def toString(): String = {
    "(" + x + ", " + y + ")"
  }

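  // Note: this is the squared Euclidean distance (no square root is taken).
  // That is enough for finding the closest center, since the ordering of
  // distances is the same either way.
  def dist(p: Point): Double = {
    return (x - p.x) * (x - p.x) + (y - p.y) * (y - p.y)
  }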
}

object kmeans extends App {
  val k: Int = 2

  // Correct answers to centers are (10, 20) and (25, 5)
  val points: List[Point] = List(
    new Point(10.8348626966492, 18.7800980127523),
    new Point(10.259545802461, 23.4515683763173),
    new Point(11.7396623802245, 17.7026240456956),
    new Point(12.4277617893575, 19.4887691804508),
    new Point(10.1057940183815, 18.7332929859685),
    new Point(11.0118378554584, 20.9773232834654),
    new Point(7.03171204763376, 19.1985058633283),
    new Point(6.56491029696013, 21.5098251711267),
    new Point(10.7751248849735, 22.1517666115673),
    new Point(8.90149362263775, 19.6314465074203),
    new Point(11.931275122466, 18.0462702532436),
    new Point(11.7265904596619, 16.9636039793709),
    new Point(11.7493214262468, 17.8517235677469),
    new Point(12.4353462881232, 19.6310467981989),
    new Point(13.0838514816799, 20.3398794353494),
    new Point(7.7875624720831, 20.1569764307574),
    new Point(11.9096128931784, 21.1855674228972),
    new Point(8.87507602702847, 21.4823134390704),
    new Point(7.91362116378194, 21.325928219919),
    new Point(26.4748241341303, 9.25128245838802),
    new Point(26.2100410238973, 5.06220487544192),
    new Point(28.1587146197594, 3.70625885635717),
    new Point(26.8942422516129, 5.02646862012427),
    new Point(23.7770902983858, 7.19445492687232),
    new Point(23.6587920739353, 3.35476798095758),
    new Point(23.7722765903534, 3.74873642284525),
    new Point(23.9177161897547, 8.1377950229489),
    new Point(22.4668345067162, 8.9705504626857),
    new Point(24.5349708443852, 5.00561881333415),
    new Point(24.3793349065557, 4.59761596097384),
    new Point(27.0339042858296, 4.4151109960116),
    new Point(21.8031183153743, 5.69297814349064),
    new Point(22.636600400773, 2.46561420928429),
    new Point(25.1439536911272, 3.58469981317611),
    new Point(21.4930923464916, 3.28999356823389),
    new Point(23.5359486724204, 4.07290025106778),
    new Point(22.5447925324242, 2.99485404382734),
    new Point(25.4645673159779, 7.54703465191098)).sortBy(
      p => (p.x + " " + p.y).hashCode())

  def clusterMean(points: List[Point]): Point = {
    val cumulative = points.reduceLeft((a: Point, b: Point) => new Point(a.x + b.x, a.y + b.y))

    return new Point(cumulative.x / points.length, cumulative.y / points.length)
  }

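  // Parts of the listing from here on were lost in extraction. The surviving
  // fragments are kept; the rest is a best-effort reconstruction that follows
  // the format of the output shown earlier.

  // Print each cluster, in index order, with its mean and its points
  def render(points: Map[Int, List[Point]]): Unit = {
    for (clusterNumber <- points.keys.toList.sorted) {
      System.out.println("  Cluster " + clusterNumber)
      System.out.println("  Mean: " + clusterMean(points(clusterNumber)))

      for (p <- points(clusterNumber)) {
        System.out.println("    " + p)
      }
    }
  }

  // Initial grouping: deal the shuffled points into k clusters by index
  val clusters =
    points.zipWithIndex.groupBy(
      x => x._2 % k) transform (
        (i: Int, p: List[(Point, Int)]) => for (x <- p) yield x._1)

  System.out.println("Initial State: ")
  render(clusters)

  def iterate(clusters: Map[Int, List[Point]]): Map[Int, List[Point]] = {
    // Strips the index from (point, index) pairs; not used below
    val unzippedClusters =
      (clusters: Iterator[(Point, Int)]) => clusters.map(cluster => cluster._1)

    // find cluster means
    val means =
      (clusters: Map[Int, List[Point]]) =>
        for (clusterIndex <- clusters.keys.toList.sorted)
          yield clusterMean(clusters(clusterIndex))

    // index of the mean closest to p
    def closest(p: Point, means: List[Point]): Int =
      means.indices.minBy(i => p.dist(means(i)))

    // regroup every point under its closest mean
    val newClusters = points.groupBy(
      (p: Point) => closest(p, means(clusters)))

    render(newClusters)

    return newClusters
  }

  var clusterToTest = clusters
  for (i <- 0 to 5) {
    System.out.println("Iteration: " + i)
    clusterToTest = iterate(clusterToTest)
  }
}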
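
The output above stops changing after iteration 1, yet the run continues through iteration 5. A minimal sketch of stopping as soon as a pass changes nothing is shown below, written generically so it does not depend on the listing. `UntilStable` and `run` are hypothetical names, and states are compared with `==`.

```scala
object UntilStable {
  // Applies `step` until the state stops changing (compared with ==) or
  // maxIter steps have been taken. Returns the final state together with
  // the number of steps that actually changed it.
  def run[A](initial: A, maxIter: Int)(step: A => A): (A, Int) = {
    @annotation.tailrec
    def loop(state: A, done: Int): (A, Int) =
      if (done >= maxIter) (state, done)
      else {
        val next = step(state)
        if (next == state) (state, done) else loop(next, done + 1)
      }
    loop(initial, 0)
  }
}
```

With a function that performs one pass over a `Map[Int, List[Point]]` (call it `onePass`, a hypothetical name), this might be used as `UntilStable.run(clusters, 20)(onePass)`, where 20 is an arbitrary cap on the number of passes.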

5 Replies to “Implementing k-means in Scala”

  1. Interesting entry. Just some remarks:

    You can write `case class Point(x: Double, y: Double)` so you don’t have to write public getters for the coordinates; then you can leave out the `new` keyword, e.g. `Point(10.835, 18.780)`.

    There are a few more ‘javaish’ things, e.g. you should omit the `return` keyword, and you can write `println` instead of `System.out.println`.

  2. In class “Point”, your “dist” function is not Euclidean distance. Is that intended?

    What you have there amounts to the squared Euclidean distance (the square root is never taken), which seems rather odd.
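
The suggestions in these replies can be sketched together as below. This is an illustration, not a change to the listing above: `Pt`, `distSquared`, and `NearestCenter.nearest` are hypothetical names. The expression used in the post's `dist` sums the squared coordinate differences, which is the squared Euclidean distance; because the square root is monotonic, both measures pick the same nearest center.

```scala
// `Pt` is a hypothetical name, chosen to avoid clashing with the post's Point.
// A case class gives x and y as public fields and allows construction
// without `new`.
case class Pt(x: Double, y: Double) {
  // Sum of squared coordinate differences (squared Euclidean distance)
  def distSquared(p: Pt): Double =
    (x - p.x) * (x - p.x) + (y - p.y) * (y - p.y)

  // Euclidean distance: the square root of the above
  def dist(p: Pt): Double = math.sqrt(distSquared(p))
}

object NearestCenter {
  // Index of the center nearest to p under the supplied distance function
  def nearest(p: Pt, centers: List[Pt], d: (Pt, Pt) => Double): Int =
    centers.indices.minBy(i => d(p, centers(i)))
}
```

For example, `NearestCenter.nearest(Pt(22.5, 8.9), List(Pt(10, 20), Pt(25, 5)), _.dist(_))` and the same call with `_.distSquared(_)` both select the second center.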
