Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clustering - DBSCAN #86

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
115 changes: 115 additions & 0 deletions src/main/scala/io/picnicml/doddlemodel/cluster/DBSCAN.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package io.picnicml.doddlemodel.cluster

import breeze.linalg.DenseVector
import breeze.linalg.functions.euclideanDistance
import cats.syntax.option._
import io.picnicml.doddlemodel.data.Features
import io.picnicml.doddlemodel.syntax.OptionSyntax._
import io.picnicml.doddlemodel.typeclasses.Clusterer

import scala.collection.mutable

/** An immutable DBSCAN clustering model.
  *
  * @param eps the maximum distance between two datapoints for them to be considered part of a
  *            common neighborhood
  * @param minSamples the minimum number of datapoints in a neighborhood (including the point
  *                   itself) for a point to be considered a core point
  * @param labels the cluster assignment of each training datapoint; `None` until the model is
  *               fitted, `-1.0` marks noise points
  */
final case class DBSCAN private (eps: Double, minSamples: Int, private val labels: Option[DenseVector[Double]])

object DBSCAN {

  // FIFO queue of row indices still to be examined while growing the current cluster.
  private type Neighbors = mutable.Queue[Int]

  // Sentinel label for a point that has not been visited yet.
  private val UNASSIGNED = Double.MaxValue
  // Label given to outliers, i.e. points that belong to no cluster.
  private val NOISE = -1.0

  /** Creates an unfitted DBSCAN model.
    *
    * @param eps maximum neighborhood distance, must be strictly positive
    * @param minSamples minimum neighborhood size (point itself included) for a core point,
    *                   must be strictly positive
    * @throws IllegalArgumentException if either hyperparameter is out of range
    */
  def apply(eps: Double = 0.5, minSamples: Int = 5): DBSCAN = {
    require(eps > 0.0, "Maximum distance eps needs to be larger than 0")
    require(minSamples > 0, "Minimum number of samples needs to be larger than 0")
    DBSCAN(eps, minSamples, none)
  }

  implicit lazy val ev: Clusterer[DBSCAN] = new Clusterer[DBSCAN] {

    override protected def copy(model: DBSCAN): DBSCAN = model.copy()

    // A model is fitted exactly when cluster labels have been computed.
    override def isFitted(model: DBSCAN): Boolean = model.labels.isDefined

    /** Runs DBSCAN over every row of x and returns a model holding the final labels.
      *
      * Every row starts as UNASSIGNED; each unvisited point is either expanded into a new
      * cluster or marked as NOISE. Note that although State is folded immutably, its labels
      * vector is mutated in place by the handlers below.
      */
    override protected def fitSafe(model: DBSCAN, x: Features): DBSCAN = {
      val nn = NearestNeighbors(x)
      val finalState = (0 until x.rows).foldLeft(State.initial(x.rows)) { case (state, rowIdx) =>
        if (state.labels(rowIdx) == UNASSIGNED)
          handleUnassignedPoint(model, rowIdx, nn, state)
        else
          state
      }
      model.copy(labels = finalState.labels.some)
    }

    /** Classifies an unvisited point: NOISE if its neighborhood is too small, otherwise the
      * seed of a brand-new cluster (clusterId is incremented before expansion).
      * A NOISE label is provisional — expandPoint may later relabel the point as a border
      * point of some cluster.
      */
    private def handleUnassignedPoint(model: DBSCAN, rowIdx: Int, nn: NearestNeighbors, s: State): State = {
      val neighbors = nn.getNeighbors(rowIdx, model.eps)
      // "+ 1" counts the point itself as part of its own neighborhood.
      if (neighbors.length + 1 < model.minSamples) {
        s.labels(rowIdx) = NOISE
        s
      }
      else
        expandPoint(model, rowIdx, mutable.Queue(neighbors:_*), nn, s.copy(clusterId = s.clusterId + 1))
    }

    /** Grows cluster s.clusterId outwards from the core point rowIdx via breadth-first search.
      *
      * Points previously labeled NOISE become border points of this cluster; unassigned
      * points are absorbed and, if they turn out to be core points themselves, their
      * neighborhoods are enqueued for further expansion. Labels are written in place.
      */
    private def expandPoint(model: DBSCAN, rowIdx: Int, neighbors: Neighbors, nn: NearestNeighbors, s: State): State = {
      s.labels(rowIdx) = s.clusterId.toDouble
      while (neighbors.nonEmpty) {
        val neighbor = neighbors.dequeue
        if (s.labels(neighbor) == NOISE)
          s.labels(neighbor) = s.clusterId.toDouble
        else if (s.labels(neighbor) == UNASSIGNED)
          neighbors.enqueueAll(processUnassignedNeighbor(model, neighbor, nn, s))
      }
      s
    }

    /** Assigns an unvisited neighbor to the current cluster and returns the points to enqueue:
      * the neighbor's own neighborhood when it is a core point, nothing when it is a border
      * point (border points do not propagate the expansion).
      */
    private def processUnassignedNeighbor(model: DBSCAN, neighbor: Int, nn: NearestNeighbors, s: State): Neighbors = {
      s.labels(neighbor) = s.clusterId.toDouble
      val neighborNeighbors = nn.getNeighbors(neighbor, model.eps)
      if (neighborNeighbors.length + 1 >= model.minSamples)
        mutable.Queue(neighborNeighbors:_*)
      else
        mutable.Queue.empty
    }

    override protected def labelsSafe(model: DBSCAN): DenseVector[Double] = model.labels.getOrBreak

    // Mutable-in-spirit fold accumulator: labels is a DenseVector that the handlers write to
    // in place; clusterId is the id of the most recently created cluster (-1 before any exist).
    private case class State (labels: DenseVector[Double], clusterId: Int)
    private object State {
      // All points start UNASSIGNED; the first real cluster will get id 0.
      def initial(numPoints: Int): State = State(DenseVector.fill[Double](numPoints)(UNASSIGNED), -1)
    }
  }

  // todo: implement NearestNeighbors with a kd-tree / ball-tree data structure and move into a separate file
  // distanceMatrix is a mapping from pairs of points indices (row indices in x) to their distances
  private class NearestNeighbors (val distanceMatrix: mutable.AnyRefMap[(Int, Int), Double], val numPoints: Int) {

    /** Returns indices of all points within eps of rowIdx, excluding rowIdx itself. O(n) scan. */
    def getNeighbors(rowIdx: Int, eps: Double): Seq[Int] = {
      (0 until numPoints).filter { candidateIndex =>
        candidateIndex != rowIdx && getDistance(candidateIndex, rowIdx) <= eps
      }
    }

    /** Looks up a pairwise distance; the matrix stores each pair once with ascending indices,
      * so the key is normalized here.
      */
    def getDistance(rowIdx0: Int, rowIdx1: Int): Double = {
      if (rowIdx0 > rowIdx1)
        distanceMatrix((rowIdx1, rowIdx0))
      else
        distanceMatrix((rowIdx0, rowIdx1))
    }
  }

  private object NearestNeighbors {
    /** Precomputes all pairwise euclidean distances between rows of x (upper triangle only). */
    def apply(x: Features): NearestNeighbors = {
      val distanceMatrix = mutable.AnyRefMap[(Int, Int), Double]()
      // combinations(2) yields each unordered pair exactly once with rowIdx0 < rowIdx1.
      (0 until x.rows).combinations(2).foreach { case rowIdx0 +: rowIdx1 +: IndexedSeq() =>
        distanceMatrix((rowIdx0, rowIdx1)) = euclideanDistance(x(rowIdx0, ::).t, x(rowIdx1, ::).t)
      }
      new NearestNeighbors(distanceMatrix, x.rows)
    }
  }
}
31 changes: 31 additions & 0 deletions src/main/scala/io/picnicml/doddlemodel/typeclasses/Clusterer.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package io.picnicml.doddlemodel.typeclasses

import breeze.linalg.DenseVector
import io.picnicml.doddlemodel.data.Features

/** A typeclass for clustering models.
  *
  * Public entry points (`fit`, `fitPredict`, `labels`) validate the fitted/unfitted state and
  * delegate to the `*Safe` variants, which implementations may assume are called with a model
  * in the correct state.
  */
trait Clusterer[A] extends Estimator[A] {

  /** Fits the clusterer to the given features and returns a fitted copy.
    *
    * @throws IllegalArgumentException if the model is already fitted
    */
  def fit(model: A, x: Features): A = {
    require(!isFitted(model), "Called fit on a model that is already fitted")
    fitSafe(copy(model), x)
  }

  /** Fits the clusterer and returns the resulting cluster labels in a single step.
    *
    * @throws IllegalArgumentException if the model is already fitted
    */
  def fitPredict(model: A, x: Features): DenseVector[Double] = {
    // Fixed copy-paste message: this is fitPredict, not fit.
    require(!isFitted(model), "Called fitPredict on a model that is already fitted")
    labelsSafe(fitSafe(copy(model), x))
  }

  /** A function that creates an identical clusterer. */
  protected def copy(model: A): A

  /** A function that is guaranteed to be called on a non-fitted model. **/
  protected def fitSafe(model: A, x: Features): A

  /** Returns the cluster label of each training datapoint.
    *
    * @throws IllegalArgumentException if the model is not fitted yet
    */
  def labels(model: A): DenseVector[Double] = {
    require(isFitted(model), "Called labels on a model that is not fitted yet")
    labelsSafe(model)
  }

  /** A function that is guaranteed to be called on a fitted model. */
  protected def labelsSafe(model: A): DenseVector[Double]
}
70 changes: 70 additions & 0 deletions src/test/scala/io/picnicml/doddlemodel/cluster/DBSCANTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package io.picnicml.doddlemodel.cluster

import breeze.linalg.{DenseMatrix, DenseVector}
import io.picnicml.doddlemodel.TestingUtils
import io.picnicml.doddlemodel.cluster.DBSCAN.ev
import org.scalactic.{Equality, TolerantNumerics}
import org.scalatest.{FlatSpec, Matchers}

class DBSCANTest extends FlatSpec with Matchers with TestingUtils {

  implicit val doubleTolerance: Equality[Double] = TolerantNumerics.tolerantDoubleEquality(1e-4)

  // Two well-separated groups of three points each, centered around (1, 1) and (8, 1).
  private val x = DenseMatrix(
    List(1.0, 1.0),
    List(0.0, 2.0),
    List(2.0, 0.0),
    List(8.0, 1.0),
    List(7.0, 2.0),
    List(9.0, 0.0)
  )

  "DBSCAN" should "cluster the datapoints" in {
    val model = DBSCAN(eps = 3.0, minSamples = 1)
    breezeEqual(ev.fitPredict(model, x), DenseVector(0.0, 0.0, 0.0, 1.0, 1.0, 1.0)) shouldEqual true
  }

  // Typo fixed in the description below: "it's" -> "its".
  it should "cluster each datapoint into its own group when eps is too small" in {
    val model = DBSCAN()
    breezeEqual(ev.fitPredict(model, x), DenseVector(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) shouldEqual true
  }

  it should "cluster all data points into a single group when eps is too large" in {
    val model = DBSCAN(eps = 10.0)
    breezeEqual(ev.fitPredict(model, x), DenseVector(0.0, 0.0, 0.0, 0.0, 0.0, 0.0)) shouldEqual true
  }

  it should "label all points as outliers when min samples is too large" in {
    val model = DBSCAN(minSamples = 7)
    breezeEqual(ev.fitPredict(model, x), DenseVector(-1.0, -1.0, -1.0, -1.0, -1.0, -1.0)) shouldEqual true
  }

  it should "cluster all datapoints into a single group when eps equals the distance between points" in {
    val smallX = DenseMatrix(
      List(0.0, 0.0),
      List(3.0, 0.0)
    )
    val model = DBSCAN(eps = 3.0)
    breezeEqual(ev.fitPredict(model, smallX), DenseVector(0.0, 0.0)) shouldEqual true
  }

  it should "cluster all datapoints into a single group" in {
    // A chain of collinear points each exactly eps apart: transitive reachability must
    // merge them all into one cluster.
    val d1X = DenseMatrix(
      List(0.0, 12.0),
      List(0.0, 9.0),
      List(0.0, 6.0),
      List(0.0, 3.0),
      List(0.0, 0.0)
    )
    val model = DBSCAN(eps = 3.0, minSamples = 3)
    breezeEqual(ev.fitPredict(model, d1X), DenseVector(0.0, 0.0, 0.0, 0.0, 0.0)) shouldEqual true
  }

  it should "prevent the usage of negative eps" in {
    an [IllegalArgumentException] shouldBe thrownBy(DBSCAN(eps = -0.5))
  }

  it should "prevent the usage of negative min samples" in {
    an [IllegalArgumentException] shouldBe thrownBy(DBSCAN(minSamples = -1))
  }
}