-
Notifications
You must be signed in to change notification settings - Fork 0
/
DBSCAN.py
100 lines (80 loc) · 2.74 KB
/
DBSCAN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from scipy.spatial import cKDTree
class DBSCAN:
    """Density-based clustering (DBSCAN) backed by a ``cKDTree``.

    Labels are integers: ``0`` marks noise and clusters are numbered
    consecutively from ``1`` in the order their first core point is scanned.
    A point is a *core* point when the ball of radius ``distance`` around it
    holds at least ``min_pts`` points, the point itself included.
    """

    def __init__(self, min_pts=4, distance=0.1, protocol=0):
        """Store the clustering parameters and create empty state.

        Args:
            min_pts: Minimum neighbourhood size (the point itself counts)
                for a point to be treated as a core point.
            distance: Neighbourhood radius, passed to ``dbscan`` as ``eps``.
            protocol: Accepted for backward compatibility; not used.
        """
        self.min_pts = min_pts
        self.distance = distance
        self.points = []
        self.tree = None
        self.clusters = []
        # Per-run bookkeeping; all of it is rebuilt by ``dbscan``.
        self.q = set()              # indices already queued for expansion
        self.memo = {}              # kept for compatibility; never read here
        self.visited = set()        # indices whose neighbourhood was queried
        self.clusters_used = set()  # cluster ids that own at least one point

    def _neighbors(self, i, eps):
        """Return the indices of all points within ``eps`` of point ``i``."""
        # No ``n_jobs``/``workers`` argument: this is a single-point query, so
        # there is nothing to parallelise, and the ``n_jobs`` keyword was
        # replaced by ``workers`` in newer SciPy releases.
        return self.tree.query_ball_point(self.points[i], r=eps)

    def range_query(self, i, eps, min, cluster):
        """Label point ``i`` and, if it is a core point, grow ``cluster``.

        Args:
            i: Index of the point to start from.
            eps: Neighbourhood radius.
            min: Minimum neighbourhood size for a core point.  (The name
                shadows the builtin but is kept so keyword callers still work.)
            cluster: Cluster id to assign.

        A point that was already visited is left untouched.  A non-core
        starting point is labelled noise (``0``) unless ``cluster`` already
        owns points, in which case it joins ``cluster``.
        """
        if i in self.visited:
            return
        self.visited.add(i)
        neighbors = self._neighbors(i, eps)
        is_core = len(neighbors) >= min

        if not is_core and cluster not in self.clusters_used:
            self.clusters[i] = 0
            return

        self.clusters[i] = cluster
        self.clusters_used.add(cluster)
        if not is_core:
            return

        # Explicit work list instead of recursion: a long density-connected
        # chain would otherwise exceed the interpreter's recursion limit.
        pending = []
        self._enqueue(neighbors, pending)
        while pending:
            p = pending.pop()
            if p in self.visited:
                # Only points previously written off as noise change label:
                # being reachable from a core point makes them border points.
                if self.clusters[p] == 0:
                    self.clusters[p] = cluster
                continue
            self.visited.add(p)
            self.clusters[p] = cluster
            reachable = self._neighbors(p, eps)
            if len(reachable) >= min:
                self._enqueue(reachable, pending)

    def _enqueue(self, candidates, pending):
        """Append to ``pending`` every candidate not queued earlier this run."""
        for p in candidates:
            if p not in self.q:
                self.q.add(p)
                pending.append(p)

    def dbscan(self, eps, min):
        """Cluster ``self.points`` and store the labels in ``self.clusters``.

        Args:
            eps: Neighbourhood radius.
            min: Minimum neighbourhood size for a core point.
        """
        total = len(self.points)
        # Rebuild all bookkeeping so repeated runs do not see stale state.
        self.clusters = [0] * total
        self.memo = {}
        self.clusters_used = set()
        self.visited = set()
        self.q = set()

        cluster = 1
        for i in range(total):
            if i in self.visited:
                continue
            # Move to a fresh id only once the current one owns points, so
            # cluster ids stay consecutive.
            if cluster in self.clusters_used:
                cluster += 1
            self.range_query(i, eps, min, cluster)

    def fit(self, points):
        """Index ``points`` in a k-d tree and cluster them.

        Args:
            points: Sequence of coordinate rows.  An empty sequence yields
                an empty label list without building a tree.
        """
        self.points = points
        if len(points) == 0:
            self.tree = None
            self.clusters = []
            self.memo = {}
            self.clusters_used = set()
            self.visited = set()
            self.q = set()
            return
        self.tree = cKDTree(points)
        self.dbscan(eps=self.distance, min=self.min_pts)

    def predict(self):
        """Return the label list computed by the most recent ``fit``."""
        return self.clusters
if __name__ == "__main__":
    # Demo: cluster uniform random 2-D points with the DBSCAN class from this
    # file and with scikit-learn's DBSCAN, then export both scatter plots.
    from pathlib import Path

    import numpy as np
    import plotly
    import plotly.express as px

    np.random.seed(0)
    points = np.random.random((1500, 2))
    xs, ys = points[:, 0], points[:, 1]

    classifier = DBSCAN(6, 0.038)
    classifier.fit(points)
    results = classifier.predict()

    # Labels are passed as strings so plotly colours them as categories.
    fig = px.scatter(x=xs, y=ys, color=list(map(str, classifier.clusters)))
    fig.show()
    fig.update_layout(height=600)
    plotly.offline.plot(fig, filename=str(Path("plot")))

    # The reference implementation is imported only once the first plot has
    # been written.
    from sklearn.cluster import DBSCAN as DBSCAN2

    classifier2 = DBSCAN2(eps=0.038, min_samples=6, algorithm="kd_tree")
    results2 = classifier2.fit_predict(points)

    # The second figure is written to disk only; it is not shown interactively.
    fig2 = px.scatter(x=xs, y=ys, color=list(map(str, results2)))
    fig2.update_layout(height=600)
    plotly.offline.plot(fig2, filename=str(Path("plot2")))

    fig.write_json("fig1.json", pretty=True)
    fig2.write_json("fig2.json")