From 549bbfbccf488d9de897e0fb7e4cdb5f66cba5ba Mon Sep 17 00:00:00 2001 From: Dennis Eichhorn Date: Thu, 17 Oct 2019 16:09:03 +0200 Subject: [PATCH] Implement kmeans algorithm --- Algorithm/Clustering/Kmeans.php | 246 ++++++++++++++++++++++ Algorithm/Clustering/Point.php | 120 +++++++++++ Algorithm/Clustering/PointInterface.php | 99 +++++++++ tests/Algorithm/Clustering/KmeansTest.php | 49 +++++ tests/Algorithm/Clustering/PointTest.php | 51 +++++ 5 files changed, 565 insertions(+) create mode 100644 Algorithm/Clustering/Kmeans.php create mode 100644 Algorithm/Clustering/Point.php create mode 100644 Algorithm/Clustering/PointInterface.php create mode 100644 tests/Algorithm/Clustering/KmeansTest.php create mode 100644 tests/Algorithm/Clustering/PointTest.php diff --git a/Algorithm/Clustering/Kmeans.php b/Algorithm/Clustering/Kmeans.php new file mode 100644 index 000000000..b9849ac8f --- /dev/null +++ b/Algorithm/Clustering/Kmeans.php @@ -0,0 +1,246 @@ +points = $points; + $this->clusters = $clusters; + $this->metric = $metric ?? function (PointInterface $a, PointInterface $b) { + $aCoordinates = $a->getCoordinates(); + $bCoordinates = $b->getCoordinates(); + + $n = \count($aCoordinates); + $sum = 0; + + for ($i = 0; $i < $n; ++$i) { + $sum = ($aCoordinates[$i] - $bCoordinates[$i]) * ($aCoordinates[$i] - $bCoordinates[$i]); + } + + return $sum; + }; + + $this->generateClusters($points, $clusters); + } + + /** + * Find the cluster for a point + * + * @param PointInterface $point Point to find the cluster for + * + * @return PointInterface + * + * @since 1.0.0 + */ + public function cluster(PointInterface $point) : PointInterface + { + $bestCluster = null; + $bestDistance = \PHP_FLOAT_MAX; + + foreach ($this->clusterCenters as $center) { + if (($distance = ($this->metric)($center, $point)) < $bestDistance) { + $bestCluster = $center; + $bestDistance = $distance; + } + } + + return $bestCluster; + } + + /** + * Generate the clusters of the points + * + * @param PointInterface[] $points Points to cluster + * @param int $clusters Amount of clusters + * + * @return void + * + * @since 1.0.0 + */ + private function generateClusters(array $points, int $clusters) : void + { + $n = \count($points); + $clusterCenters = $this->kpp($points, $clusters); + $coordinates = \count($points[0]->getCoordinates()); + + while (true) { + foreach ($clusterCenters as $center) { + for ($i = 0; $i < $coordinates; ++$i) { + $center->setCoordinate($i, 0); + } + + $center->setGroup(0); + } + + foreach ($points as $point) { + $clusterPoint = $clusterCenters[$point->getGroup()]; + + $clusterPoint->setGroup( + $clusterPoint->getGroup() + 1 + ); + + for ($i = 0; $i < $coordinates; ++$i) { + $clusterPoint->setCoordinate($i, $clusterPoint->getCoordinate($i) + $point->getCoordinate($i)); + } + } + + foreach ($clusterCenters as $center) { + for ($i = 0; $i < $coordinates; ++$i) { + $center->setCoordinate($i, $center->getCoordinate($i) / $center->getGroup()); + } + } + + $changed = 0; + foreach ($points as $point) { + $min = $this->nearestClusterCenter($point, $clusterCenters)[0]; + + if ($min !== $point->getGroup()) { + ++$changed; + $point->setGroup($min); + } + } + + if ($changed <= $n * 0.001 || $n * 0.001 < 2) { + break; + } + } + + foreach ($clusterCenters as $key => $center) { + $center->setGroup($key); + $center->setName((string) $key); + } + + $this->clusterCenters = $clusterCenters; + } + + /** + * Get the index and distance to the nearest cluster center + * + * @param PointInterface $point Point to get the cluster for + * @param PointInterface[] $clusterCenters All cluster centers + * + * @return array [index, distance] + * + * @since 1.0.0 + */ + private function nearestClusterCenter(PointInterface $point, array $clusterCenters) : array + { + $index = $point->getGroup(); + $dist = \PHP_FLOAT_MAX; + + foreach ($clusterCenters as $key => $cPoint) { + $d = ($this->metric)($cPoint, $point); + + if ($dist > $d) { + $dist = $d; + $index = $key; + } + } + + return [$index, $dist]; + } + + /** + * Initializae cluster centers + * + * @param PointInterface[] $points Points to use for the cluster center initialization + * @param int $n Amount of clusters to use + * + * @return array + * + * @since 1.0.0 + */ + private function kpp(array $points, int $n) : array + { + $clusters = [clone $points[\mt_rand(0, \count($points) - 1)]]; + $d = \array_fill(0, $n, 0.0); + + for ($i = 1; $i < $n; ++$i) { + $sum = 0; + + foreach ($points as $key => $point) { + $d[$key] = $this->nearestClusterCenter($point, \array_slice($clusters, 0, 5))[1]; + $sum += $d[$key]; + } + + $sum *= \mt_rand(0, \mt_getrandmax()) / \mt_getrandmax(); + + foreach ($d as $key => $di) { + $sum -= $di; + + if ($sum <= 0) { + $clusters[$i] = clone $points[$key]; + } + } + } + + foreach ($points as $point) { + $point->setGroup($this->nearestClusterCenter($point, $clusters)[0]); + } + + return $clusters; + } +} \ No newline at end of file diff --git a/Algorithm/Clustering/Point.php b/Algorithm/Clustering/Point.php new file mode 100644 index 000000000..bd1546c9c --- /dev/null +++ b/Algorithm/Clustering/Point.php @@ -0,0 +1,120 @@ +coordinates = $coordinates; + $this->name = $name; + } + + /** + * {@inheritdoc} + */ + public function getCoordinates(): array + { + return $this->coordinates; + } + + /** + * {@inheritdoc} + */ + public function getCoordinate($index) + { + return $this->coordinates[$index]; + } + + /** + * {@inheritdoc} + */ + public function setCoordinate($index, $value) + { + $this->coordinates[$index] = $value; + } + + /** + * {@inheritdoc} + */ + public function getGroup() : int + { + return $this->group; + } + + /** + * {@inheritdoc} + */ + public function setGroup(int $group) : void + { + $this->group = $group; + } + + /** + * {@inheritdoc} + */ + public function setName(string $name) : void + { + $this->name = $name; + } + + /** + * {@inheritdoc} + */ + public function getName() : string + { + return $this->name; + } +} \ No newline at end of file diff --git a/Algorithm/Clustering/PointInterface.php b/Algorithm/Clustering/PointInterface.php new file mode 100644 index 000000000..a14d436ff --- /dev/null +++ b/Algorithm/Clustering/PointInterface.php @@ -0,0 +1,99 @@ +cluster($points[0])->getGroup()); + self::assertEquals(0, $kmeans->cluster($points[1])->getGroup()); + + self::assertEquals(1, $kmeans->cluster($points[2])->getGroup()); + self::assertEquals(1, $kmeans->cluster($points[3])->getGroup()); + self::assertEquals(1, $kmeans->cluster($points[4])->getGroup()); + self::assertEquals(1, $kmeans->cluster($points[5])->getGroup()); + self::assertEquals(1, $kmeans->cluster($points[6])->getGroup()); + } +} \ No newline at end of file diff --git a/tests/Algorithm/Clustering/PointTest.php b/tests/Algorithm/Clustering/PointTest.php new file mode 100644 index 000000000..71b4fad51 --- /dev/null +++ b/tests/Algorithm/Clustering/PointTest.php @@ -0,0 +1,51 @@ +getCoordinates()); + self::assertEquals(3.0, $point->getCoordinate(0)); + self::assertEquals(2.0, $point->getCoordinate(1)); + self::assertEquals(0, $point->getGroup()); + self::assertEquals('abc', $point->getName()); + } + + public function testSetGet() : void + { + $point = new Point([3.0, 2.0], 'abc'); + + $point->setCoordinate(0, 4.0); + $point->setCoordinate(1, 1.0); + + self::assertEquals([4.0, 1.0], $point->getCoordinates()); + self::assertEquals(4.0, $point->getCoordinate(0)); + self::assertEquals(1.0, $point->getCoordinate(1)); + + $point->setGroup(2); + self::assertEquals(2, $point->getGroup()); + } +}