## Implementing k-means in Scala

To generate sample data, I picked two centers, (10, 20) and (25, 5), and generated normally distributed points around each of them. The exact points used are listed in the code below.
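The post lists its points rather than the code that produced them, so here is a minimal sketch of one way to generate similar data. Everything in it is an assumption: the `SampleData` object and `Point` case class are hypothetical, the spread (sigma = 2.0) is a guess, and 19 points per center simply matches the 38 points listed in the code below. It uses `scala.util.Random.nextGaussian()` with a fixed seed so runs are reproducible.

```scala
import scala.util.Random

// Hypothetical generator, not the author's code; sigma = 2.0 is an assumption.
object SampleData {
  final case class Point(x: Double, y: Double)

  // n points around `center`, each coordinate offset by an N(0, sigma^2) draw
  def around(center: Point, n: Int, sigma: Double, rng: Random): List[Point] =
    List.fill(n)(Point(center.x + sigma * rng.nextGaussian(), center.y + sigma * rng.nextGaussian()))

  // 19 points around each of the two centers used in the post
  def generate(seed: Long = 42L, sigma: Double = 2.0): List[Point] = {
    val rng = new Random(seed) // fixed seed: the same points on every run
    around(Point(10, 20), 19, sigma, rng) ++ around(Point(25, 5), 19, sigma, rng)
  }
}
```

Calling `SampleData.generate()` returns 38 points, the first 19 scattered around (10, 20) and the rest around (25, 5).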

The code implements Lloyd’s algorithm, which clusters points iteratively:

1. Choose the number of clusters, k
2. Assign the points to k groups at random
3. Compute the center (mean) of each group
4. For each point, find the closest center
5. Regroup all the points according to their closest centers
6. Repeat steps 3-5 until the groupings stop changing, or for a fixed number of iterations
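The six steps map fairly directly onto a few pure functions. This is a compact sketch, not the author's implementation: it assumes a hypothetical immutable `Point` case class, shuffles with `scala.util.Random.shuffle` for the random initial grouping, and runs a fixed number of iterations. A cluster that ends up with no points simply disappears from the map, which the sketch does not try to handle.

```scala
import scala.util.Random

// Hypothetical sketch of Lloyd's algorithm; names are illustrative only.
object Lloyd {
  final case class Point(x: Double, y: Double) {
    // Squared distance is enough to decide which mean is closest
    def dist2(o: Point): Double = (x - o.x) * (x - o.x) + (y - o.y) * (y - o.y)
  }

  def mean(ps: Seq[Point]): Point =
    Point(ps.map(_.x).sum / ps.size, ps.map(_.y).sum / ps.size)

  // Steps 3-5: compute means from the old grouping, then build a new grouping
  def step(clusters: Map[Int, Seq[Point]]): Map[Int, Seq[Point]] = {
    val means: Map[Int, Point] = clusters.map { case (i, ps) => i -> mean(ps) }
    val all = clusters.values.flatten.toSeq
    all.groupBy(p => means.minBy { case (_, m) => p.dist2(m) }._1)
  }

  // Steps 1-2 and 6: deal shuffled points into k groups, then iterate
  def run(points: Seq[Point], k: Int, iterations: Int, rng: Random): Map[Int, Seq[Point]] = {
    val initial: Map[Int, Seq[Point]] =
      rng.shuffle(points).zipWithIndex.groupBy(_._2 % k).map { case (i, ps) => i -> ps.map(_._1) }
    (1 to iterations).foldLeft(initial)((clusters, _) => step(clusters))
  }
}
```

Because `step` never modifies its input, the previous grouping is still available after each pass, which is what makes the "compare old and new" style of convergence check straightforward.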

I like how the functional programming style forces you to build new data structures on every pass. It might be tempting to implement this imperatively, modifying the clusters in place, but steps 4-5 assign points using means computed from the previous grouping, so that grouping has to stay intact while the new one is built. Recreating the structures each time rules out accidentally mixing old and new state. You can see the full source below, or on GitHub.

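Because the example is contrived (two well-separated groups of points), the algorithm converges quickly: from iteration 1 onward, the clusters and their means no longer change.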

```
Initial State:
Cluster 0
Mean: (17.83517750970944, 12.242720407317105)
(10.8348626966492, 18.7800980127523)
(7.7875624720831, 20.1569764307574)
(11.9096128931784, 21.1855674228972)
(22.4668345067162, 8.9705504626857)
(7.91362116378194, 21.325928219919)
(22.636600400773, 2.46561420928429)
(13.0838514816799, 20.3398794353494)
(11.7396623802245, 17.7026240456956)
(25.1439536911272, 3.58469981317611)
(23.5359486724204, 4.07290025106778)
(11.7493214262468, 17.8517235677469)
(12.4277617893575, 19.4887691804508)
(11.931275122466, 18.0462702532436)
(25.4645673159779, 7.54703465191098)
(21.8031183153743, 5.69297814349064)
(23.9177161897547, 8.1377950229489)
(24.5349708443852, 5.00561881333415)
(26.2100410238973, 5.06220487544192)
(23.7770902983858, 7.19445492687232)

Cluster 1
Mean: (16.95249500233747, 12.848199048608048)
(11.7265904596619, 16.9636039793709)
(10.7751248849735, 22.1517666115673)
(23.6587920739353, 3.35476798095758)
(21.4930923464916, 3.28999356823389)
(26.4748241341303, 9.25128245838802)
(7.03171204763376, 19.1985058633283)
(23.7722765903534, 3.74873642284525)
(10.259545802461, 23.4515683763173)
(28.1587146197594, 3.70625885635717)
(10.1057940183815, 18.7332929859685)
(8.90149362263775, 19.6314465074203)
(12.4353462881232, 19.6310467981989)
(24.3793349065557, 4.59761596097384)
(22.5447925324242, 2.99485404382734)
(26.8942422516129, 5.02646862012427)
(6.56491029696013, 21.5098251711267)
(8.87507602702847, 21.4823134390704)
(27.0339042858296, 4.4151109960116)
(11.0118378554584, 20.9773232834654)

Iteration: 0
Cluster 0
Mean: (23.781370272978315, 5.754127202865132)
(11.7265904596619, 16.9636039793709)
(23.6587920739353, 3.35476798095758)
(22.4668345067162, 8.9705504626857)
(21.4930923464916, 3.28999356823389)
(26.4748241341303, 9.25128245838802)
(22.636600400773, 2.46561420928429)
(23.7722765903534, 3.74873642284525)
(25.1439536911272, 3.58469981317611)
(28.1587146197594, 3.70625885635717)
(23.5359486724204, 4.07290025106778)
(24.3793349065557, 4.59761596097384)
(25.4645673159779, 7.54703465191098)
(22.5447925324242, 2.99485404382734)
(21.8031183153743, 5.69297814349064)
(26.8942422516129, 5.02646862012427)
(23.9177161897547, 8.1377950229489)
(24.5349708443852, 5.00561881333415)
(26.2100410238973, 5.06220487544192)
(27.0339042858296, 4.4151109960116)
(23.7770902983858, 7.19445492687232)

Cluster 1
Mean: (10.296576237184727, 20.09138475584863)
(10.8348626966492, 18.7800980127523)
(7.7875624720831, 20.1569764307574)
(10.7751248849735, 22.1517666115673)
(11.9096128931784, 21.1855674228972)
(7.91362116378194, 21.325928219919)
(7.03171204763376, 19.1985058633283)
(13.0838514816799, 20.3398794353494)
(11.7396623802245, 17.7026240456956)
(10.259545802461, 23.4515683763173)
(10.1057940183815, 18.7332929859685)
(11.7493214262468, 17.8517235677469)
(8.90149362263775, 19.6314465074203)
(12.4277617893575, 19.4887691804508)
(12.4353462881232, 19.6310467981989)
(11.931275122466, 18.0462702532436)
(6.56491029696013, 21.5098251711267)
(8.87507602702847, 21.4823134390704)
(11.0118378554584, 20.9773232834654)

Iteration: 1
Cluster 0
Mean: (24.415832368416023, 5.164154740943777)
(23.6587920739353, 3.35476798095758)
(22.4668345067162, 8.9705504626857)
(21.4930923464916, 3.28999356823389)
(26.4748241341303, 9.25128245838802)
(22.636600400773, 2.46561420928429)
(23.7722765903534, 3.74873642284525)
(25.1439536911272, 3.58469981317611)
(28.1587146197594, 3.70625885635717)
(23.5359486724204, 4.07290025106778)
(24.3793349065557, 4.59761596097384)
(25.4645673159779, 7.54703465191098)
(22.5447925324242, 2.99485404382734)
(21.8031183153743, 5.69297814349064)
(26.8942422516129, 5.02646862012427)
(23.9177161897547, 8.1377950229489)
(24.5349708443852, 5.00561881333415)
(26.2100410238973, 5.06220487544192)
(27.0339042858296, 4.4151109960116)
(23.7770902983858, 7.19445492687232)

Cluster 1
Mean: (10.371840143630894, 19.92676471498138)
(10.8348626966492, 18.7800980127523)
(11.7265904596619, 16.9636039793709)
(7.7875624720831, 20.1569764307574)
(10.7751248849735, 22.1517666115673)
(11.9096128931784, 21.1855674228972)
(7.91362116378194, 21.325928219919)
(7.03171204763376, 19.1985058633283)
(13.0838514816799, 20.3398794353494)
(11.7396623802245, 17.7026240456956)
(10.259545802461, 23.4515683763173)
(10.1057940183815, 18.7332929859685)
(11.7493214262468, 17.8517235677469)
(8.90149362263775, 19.6314465074203)
(12.4277617893575, 19.4887691804508)
(12.4353462881232, 19.6310467981989)
(11.931275122466, 18.0462702532436)
(6.56491029696013, 21.5098251711267)
(8.87507602702847, 21.4823134390704)
(11.0118378554584, 20.9773232834654)

Iteration: 2
Cluster 0
Mean: (24.415832368416023, 5.164154740943777)
(23.6587920739353, 3.35476798095758)
(22.4668345067162, 8.9705504626857)
(21.4930923464916, 3.28999356823389)
(26.4748241341303, 9.25128245838802)
(22.636600400773, 2.46561420928429)
(23.7722765903534, 3.74873642284525)
(25.1439536911272, 3.58469981317611)
(28.1587146197594, 3.70625885635717)
(23.5359486724204, 4.07290025106778)
(24.3793349065557, 4.59761596097384)
(25.4645673159779, 7.54703465191098)
(22.5447925324242, 2.99485404382734)
(21.8031183153743, 5.69297814349064)
(26.8942422516129, 5.02646862012427)
(23.9177161897547, 8.1377950229489)
(24.5349708443852, 5.00561881333415)
(26.2100410238973, 5.06220487544192)
(27.0339042858296, 4.4151109960116)
(23.7770902983858, 7.19445492687232)

Cluster 1
Mean: (10.371840143630894, 19.92676471498138)
(10.8348626966492, 18.7800980127523)
(11.7265904596619, 16.9636039793709)
(7.7875624720831, 20.1569764307574)
(10.7751248849735, 22.1517666115673)
(11.9096128931784, 21.1855674228972)
(7.91362116378194, 21.325928219919)
(7.03171204763376, 19.1985058633283)
(13.0838514816799, 20.3398794353494)
(11.7396623802245, 17.7026240456956)
(10.259545802461, 23.4515683763173)
(10.1057940183815, 18.7332929859685)
(11.7493214262468, 17.8517235677469)
(8.90149362263775, 19.6314465074203)
(12.4277617893575, 19.4887691804508)
(12.4353462881232, 19.6310467981989)
(11.931275122466, 18.0462702532436)
(6.56491029696013, 21.5098251711267)
(8.87507602702847, 21.4823134390704)
(11.0118378554584, 20.9773232834654)

Iteration: 3
Cluster 0
Mean: (24.415832368416023, 5.164154740943777)
(23.6587920739353, 3.35476798095758)
(22.4668345067162, 8.9705504626857)
(21.4930923464916, 3.28999356823389)
(26.4748241341303, 9.25128245838802)
(22.636600400773, 2.46561420928429)
(23.7722765903534, 3.74873642284525)
(25.1439536911272, 3.58469981317611)
(28.1587146197594, 3.70625885635717)
(23.5359486724204, 4.07290025106778)
(24.3793349065557, 4.59761596097384)
(25.4645673159779, 7.54703465191098)
(22.5447925324242, 2.99485404382734)
(21.8031183153743, 5.69297814349064)
(26.8942422516129, 5.02646862012427)
(23.9177161897547, 8.1377950229489)
(24.5349708443852, 5.00561881333415)
(26.2100410238973, 5.06220487544192)
(27.0339042858296, 4.4151109960116)
(23.7770902983858, 7.19445492687232)

Cluster 1
Mean: (10.371840143630894, 19.92676471498138)
(10.8348626966492, 18.7800980127523)
(11.7265904596619, 16.9636039793709)
(7.7875624720831, 20.1569764307574)
(10.7751248849735, 22.1517666115673)
(11.9096128931784, 21.1855674228972)
(7.91362116378194, 21.325928219919)
(7.03171204763376, 19.1985058633283)
(13.0838514816799, 20.3398794353494)
(11.7396623802245, 17.7026240456956)
(10.259545802461, 23.4515683763173)
(10.1057940183815, 18.7332929859685)
(11.7493214262468, 17.8517235677469)
(8.90149362263775, 19.6314465074203)
(12.4277617893575, 19.4887691804508)
(12.4353462881232, 19.6310467981989)
(11.931275122466, 18.0462702532436)
(6.56491029696013, 21.5098251711267)
(8.87507602702847, 21.4823134390704)
(11.0118378554584, 20.9773232834654)

Iteration: 4
Cluster 0
Mean: (24.415832368416023, 5.164154740943777)
(23.6587920739353, 3.35476798095758)
(22.4668345067162, 8.9705504626857)
(21.4930923464916, 3.28999356823389)
(26.4748241341303, 9.25128245838802)
(22.636600400773, 2.46561420928429)
(23.7722765903534, 3.74873642284525)
(25.1439536911272, 3.58469981317611)
(28.1587146197594, 3.70625885635717)
(23.5359486724204, 4.07290025106778)
(24.3793349065557, 4.59761596097384)
(25.4645673159779, 7.54703465191098)
(22.5447925324242, 2.99485404382734)
(21.8031183153743, 5.69297814349064)
(26.8942422516129, 5.02646862012427)
(23.9177161897547, 8.1377950229489)
(24.5349708443852, 5.00561881333415)
(26.2100410238973, 5.06220487544192)
(27.0339042858296, 4.4151109960116)
(23.7770902983858, 7.19445492687232)

Cluster 1
Mean: (10.371840143630894, 19.92676471498138)
(10.8348626966492, 18.7800980127523)
(11.7265904596619, 16.9636039793709)
(7.7875624720831, 20.1569764307574)
(10.7751248849735, 22.1517666115673)
(11.9096128931784, 21.1855674228972)
(7.91362116378194, 21.325928219919)
(7.03171204763376, 19.1985058633283)
(13.0838514816799, 20.3398794353494)
(11.7396623802245, 17.7026240456956)
(10.259545802461, 23.4515683763173)
(10.1057940183815, 18.7332929859685)
(11.7493214262468, 17.8517235677469)
(8.90149362263775, 19.6314465074203)
(12.4277617893575, 19.4887691804508)
(12.4353462881232, 19.6310467981989)
(11.931275122466, 18.0462702532436)
(6.56491029696013, 21.5098251711267)
(8.87507602702847, 21.4823134390704)
(11.0118378554584, 20.9773232834654)

Iteration: 5
Cluster 0
Mean: (24.415832368416023, 5.164154740943777)
(23.6587920739353, 3.35476798095758)
(22.4668345067162, 8.9705504626857)
(21.4930923464916, 3.28999356823389)
(26.4748241341303, 9.25128245838802)
(22.636600400773, 2.46561420928429)
(23.7722765903534, 3.74873642284525)
(25.1439536911272, 3.58469981317611)
(28.1587146197594, 3.70625885635717)
(23.5359486724204, 4.07290025106778)
(24.3793349065557, 4.59761596097384)
(25.4645673159779, 7.54703465191098)
(22.5447925324242, 2.99485404382734)
(21.8031183153743, 5.69297814349064)
(26.8942422516129, 5.02646862012427)
(23.9177161897547, 8.1377950229489)
(24.5349708443852, 5.00561881333415)
(26.2100410238973, 5.06220487544192)
(27.0339042858296, 4.4151109960116)
(23.7770902983858, 7.19445492687232)

Cluster 1
Mean: (10.371840143630894, 19.92676471498138)
(10.8348626966492, 18.7800980127523)
(11.7265904596619, 16.9636039793709)
(7.7875624720831, 20.1569764307574)
(10.7751248849735, 22.1517666115673)
(11.9096128931784, 21.1855674228972)
(7.91362116378194, 21.325928219919)
(7.03171204763376, 19.1985058633283)
(13.0838514816799, 20.3398794353494)
(11.7396623802245, 17.7026240456956)
(10.259545802461, 23.4515683763173)
(10.1057940183815, 18.7332929859685)
(11.7493214262468, 17.8517235677469)
(8.90149362263775, 19.6314465074203)
(12.4277617893575, 19.4887691804508)
(12.4353462881232, 19.6310467981989)
(11.931275122466, 18.0462702532436)
(6.56491029696013, 21.5098251711267)
(8.87507602702847, 21.4823134390704)
(11.0118378554584, 20.9773232834654)
```
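The output keeps printing identical iterations because the loop runs a fixed number of times. Since each pass produces a new immutable `Map`, stopping early only needs an equality check between successive states. The sketch below is a generic helper; the name `untilStable` and its parameters are hypothetical, not part of the post's code. With the post's `Point` class, which does not override `equals`, `==` would still detect a repeat as long as each regrouping reuses the same `Point` instances in the same order, as regrouping with `points.groupBy` does. Under that assumption, on the run above it would stop on the third pass, the first one that reproduces the previous grouping.

```scala
// Hypothetical helper: apply `step` until two successive states are equal
// (by ==) or `maxIter` steps have run. Returns the final state and step count.
object Converge {
  @annotation.tailrec
  def untilStable[A](state: A, maxIter: Int, done: Int = 0)(step: A => A): (A, Int) =
    if (done >= maxIter) (state, done)
    else {
      val next = step(state)
      // Immutable Scala collections compare structurally, so == spots a repeat
      if (next == state) (next, done + 1)
      else untilStable(next, maxIter, done + 1)(step)
    }
}
```

Putting `step` in its own parameter list lets Scala infer `A` from the initial state, so a call like `Converge.untilStable(clusters, 10)(regroup)` needs no type annotations (`regroup` here stands for whatever function performs one pass of steps 3-5).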
```scala
// A 2-D point
class Point(dx: Double, dy: Double) {
  val x: Double = dx
  val y: Double = dy

  override def toString(): String = {
    "(" + x + ", " + y + ")"
  }

  // Squared Euclidean distance: skipping the square root does not change
  // which mean is closest, so it is enough for the assignment step
  def dist(p: Point): Double = {
    (x - p.x) * (x - p.x) + (y - p.y) * (y - p.y)
  }
}

object kmeans extends App {
  val k: Int = 2

  // Correct answers to centers are (10, 20) and (25, 5)
  val points: List[Point] = List(
    new Point(10.8348626966492, 18.7800980127523),
    new Point(10.259545802461, 23.4515683763173),
    new Point(11.7396623802245, 17.7026240456956),
    new Point(12.4277617893575, 19.4887691804508),
    new Point(10.1057940183815, 18.7332929859685),
    new Point(11.0118378554584, 20.9773232834654),
    new Point(7.03171204763376, 19.1985058633283),
    new Point(6.56491029696013, 21.5098251711267),
    new Point(10.7751248849735, 22.1517666115673),
    new Point(8.90149362263775, 19.6314465074203),
    new Point(11.931275122466, 18.0462702532436),
    new Point(11.7265904596619, 16.9636039793709),
    new Point(11.7493214262468, 17.8517235677469),
    new Point(12.4353462881232, 19.6310467981989),
    new Point(13.0838514816799, 20.3398794353494),
    new Point(7.7875624720831, 20.1569764307574),
    new Point(11.9096128931784, 21.1855674228972),
    new Point(8.87507602702847, 21.4823134390704),
    new Point(7.91362116378194, 21.325928219919),
    new Point(26.4748241341303, 9.25128245838802),
    new Point(26.2100410238973, 5.06220487544192),
    new Point(28.1587146197594, 3.70625885635717),
    new Point(26.8942422516129, 5.02646862012427),
    new Point(23.7770902983858, 7.19445492687232),
    new Point(23.6587920739353, 3.35476798095758),
    new Point(23.7722765903534, 3.74873642284525),
    new Point(23.9177161897547, 8.1377950229489),
    new Point(22.4668345067162, 8.9705504626857),
    new Point(24.5349708443852, 5.00561881333415),
    new Point(24.3793349065557, 4.59761596097384),
    new Point(27.0339042858296, 4.4151109960116),
    new Point(21.8031183153743, 5.69297814349064),
    new Point(22.636600400773, 2.46561420928429),
    new Point(25.1439536911272, 3.58469981317611),
    new Point(21.4930923464916, 3.28999356823389),
    new Point(23.5359486724204, 4.07290025106778),
    new Point(22.5447925324242, 2.99485404382734),
    new Point(25.4645673159779, 7.54703465191098)).sortBy(
    // Sorting by a hash of the coordinates shuffles the list deterministically
    p => (p.x + " " + p.y).hashCode())

  // Step 3: the mean (center) of one cluster
  def clusterMean(points: List[Point]): Point = {
    val cumulative = points.reduceLeft((a: Point, b: Point) => new Point(a.x + b.x, a.y + b.y))

    new Point(cumulative.x / points.length, cumulative.y / points.length)
  }

  // Print each cluster's number, mean, and members
  def render(points: Map[Int, List[Point]]): Unit = {
    for (clusterNumber <- points.keys.toList.sorted) {
      println("Cluster " + clusterNumber)
      println("Mean: " + clusterMean(points(clusterNumber)))
      points(clusterNumber).foreach(println)
      println()
    }
  }

  // Step 2: deal the shuffled points round-robin into k initial clusters
  val clusters: Map[Int, List[Point]] =
    points.zipWithIndex.groupBy(x => x._2 % k) transform (
      (i: Int, p: List[(Point, Int)]) => for (x <- p) yield x._1)

  def iterate(clusters: Map[Int, List[Point]]): Map[Int, List[Point]] = {
    // find cluster means, computed once from the previous grouping
    val means: Map[Int, Point] =
      clusters.map { case (clusterIndex, members) => clusterIndex -> clusterMean(members) }

    // Step 4: the number of the cluster whose mean is closest to p
    def closest(p: Point): Int =
      means.minBy { case (_, mean) => p.dist(mean) }._1

    // Step 5: rebuild the clusters as a new Map; the old one is untouched
    val newClusters = points.groupBy(p => closest(p))

    render(newClusters)

    newClusters
  }

  println("Initial State:")
  render(clusters)

  // Step 6: repeat steps 3-5 a fixed number of times
  var clusterToTest = clusters
  for (i <- 0 until 6) {
    println("Iteration: " + i)
    clusterToTest = iterate(clusterToTest)
  }
}
```