Skip to content

Commit

Permalink
Optimize Butina Clustering for Performance and Expand API (rdkit#7892)
Browse files Browse the repository at this point in the history
* Improve Butina clustering performance and readability with NumPy optimizations

* Code cleanup, improved input data type flexibility, and corrected test expectation

In Butina.py:
-  Cleaned up stale subversion comment lines
-  Removed type annotation
-  Allowed input data as tuple
-  Updated input data type check if condition to increase readability

In UnitTestButina.py:
- Cleaned up stale subversion comment lines
- Updated test10 invalid input data type check from expecting ValueError to TypeError
  • Loading branch information
evasnow1992 authored Oct 11, 2024
1 parent 0909d75 commit 4bffa95
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 66 deletions.
153 changes: 91 additions & 62 deletions rdkit/ML/Cluster/Butina.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,31 @@
# $Id$
#
# Copyright (C) 2007-2008 Greg Landrum
# Copyright (C) 2007-2024 Greg Landrum and other RDKit contributors
# All Rights Reserved
#
""" Implementation of the clustering algorithm published in:
Butina JCICS 39 747-750 (1999)
"""
import numpy

from rdkit import RDLogger

logger = RDLogger.logger()

import numpy as np

def EuclideanDist(pi, pj):
dv = numpy.array(pi) - numpy.array(pj)
return numpy.sqrt(dv * dv)

"""Calculate the Euclidean distance between two points."""
pi, pj = np.asarray(pi), np.asarray(pj)
return np.sqrt(np.sum((pi - pj) ** 2))

def compute_distance_matrix(data, n_pts, dist_func):
"""Compute the distance matrix for the given data."""
dist_matrix = np.zeros((n_pts, n_pts))
for i in range(n_pts):
for j in range(i):
dist_matrix[i, j] = dist_matrix[j, i] = dist_func(data[i], data[j])
return dist_matrix

def ClusterData(data, nPts, distThresh, isDistData=False, distFunc=EuclideanDist, reordering=False):
""" clusters the data points passed in and returns the list of clusters
**Arguments**
- data: a list of items with the input data
- data: a list, tuple, or numpy array of items with the input data
(see discussion of _isDistData_ argument for the exception)
- nPts: the number of points to be used
Expand All @@ -33,8 +34,10 @@ def ClusterData(data, nPts, distThresh, isDistData=False, distFunc=EuclideanDist
to be neighbors
- isDistData: set this toggle when the data passed in is a
distance matrix. The distance matrix should be stored
symmetrically. An example of how to do this:
distance matrix. The distance matrix should be stored
in one of two formats: as an nxn NumPy array, or as a
symmetrically stored list or 1D array generated using a
similar process to the example below:
dists = []
for i in range(nPts):
Expand All @@ -59,57 +62,83 @@ def ClusterData(data, nPts, distThresh, isDistData=False, distFunc=EuclideanDist
The first element for each cluster is its centroid.
"""
if isDistData and len(data) > (nPts * (nPts - 1) / 2):
logger.warning("Distance matrix is too long")
nbrLists = [None] * nPts
for i in range(nPts):
nbrLists[i] = []

dmIdx = 0
for i in range(nPts):
for j in range(i):
if not isDistData:
dij = distFunc(data[i], data[j])
else:
dij = data[dmIdx]
dmIdx += 1
if dij <= distThresh:
nbrLists[i].append(j)
nbrLists[j].append(i)
# sort by the number of neighbors:
tLists = [(len(y), x) for x, y in enumerate(nbrLists)]
tLists.sort(reverse=True)

res = []
seen = [0] * nPts
while tLists:
_, idx = tLists.pop(0)
if isDistData:
# Check if data is a supported type
if not isinstance(data, (list, tuple, np.ndarray)):
raise TypeError(f"Unsupported type for data, {type(data)}")

# Check if data is a 1D array or list
if isinstance(data, (list, tuple)) or (isinstance(data, np.ndarray) and data.ndim == 1):
# Check if data length matches the required number of points
if len(data) != (nPts * (nPts - 1)) // 2:
raise ValueError("Mismatched input data dimension and nPts")

# Create a distance matrix from the 1D data
dist_matrix = np.zeros((nPts, nPts))
idx = np.tril_indices(nPts, -1)
dist_matrix[idx] = data
dist_matrix += dist_matrix.T
else:
# Check if data is a matrix of the correct shape and use it as distance matrix
if data.shape != (nPts, nPts):
raise ValueError(f"Input data with shape {data.shape} is not a matrix of the required shape {(nPts, nPts)}")
dist_matrix = data
else:
# Compute distance matrix from the data points
dist_matrix = compute_distance_matrix(data, nPts, distFunc)

# Initialize neighbor lists
neighbor_lists = [np.where(dist_matrix[i] <= distThresh)[0].tolist() for i in range(nPts)]

# Sort points by the number of neighbors in descending order
sorted_indices = [(len(neighbors), idx) for idx, neighbors in enumerate(neighbor_lists)]
sorted_indices.sort(reverse=True)

# Initialize clusters and a seen array to keep track of processed points
clusters = []
seen = np.zeros(nPts, dtype=bool)

# Process all candidate clusters that have at least two members
while sorted_indices and sorted_indices[0][0] > 1:
_, idx = sorted_indices.pop(0)
if seen[idx]:
continue
tRes = [idx]
for nbr in nbrLists[idx]:
if not seen[nbr]:
tRes.append(nbr)
seen[nbr] = 1
# update the number of neighbors:

# Create a new cluster and mark points as seen
cluster = [idx]
seen[idx] = True
for neighbor in neighbor_lists[idx]:
if not seen[neighbor]:
cluster.append(neighbor)
seen[neighbor] = True

clusters.append(tuple(cluster))

# Update the number of neighbors:
# remove all members of the new cluster from the list of
# neighbors and reorder the tLists
# neighbors and reorder the sorted_indices
if reordering:
# get the list of affected molecules, i.e. all molecules
# Get the set of unassigned and affected molecules, i.e. all unseen molecules
# which have at least one of the members of the new cluster
# as a neighbor
nbrNbr = [nbrLists[t] for t in tRes]
nbrNbr = frozenset().union(*nbrNbr)
# loop over all remaining molecules in tLists but only
affected = set(neighbor for point in cluster for neighbor in neighbor_lists[point] if not seen[neighbor])

# Loop over all remaining molecules in sorted_indices but only
# consider unassigned and affected compounds
for x, y in enumerate(tLists):
y1 = y[1]
if seen[y1] or (y1 not in nbrNbr):
continue
# update the number of neighbors
nbrLists[y1] = set(nbrLists[y1]).difference(tRes)
tLists[x] = (len(nbrLists[y1]), y1)
# now reorder the list
tLists.sort(reverse=True)
res.append(tuple(tRes))
return tuple(res)
for ii, element in enumerate(sorted_indices):
affected_point = element[1]
if affected_point in affected:
# Update the number of neighbors
new_neighbors = [nbr for nbr in neighbor_lists[affected_point] if not seen[nbr]]
neighbor_lists[affected_point] = new_neighbors
sorted_indices[ii] = (len(new_neighbors), affected_point)
# Reorder the list
sorted_indices.sort(reverse=True)

# Process any remaining single-point clusters
while sorted_indices:
_, idx = sorted_indices.pop(0)
if seen[idx]:
continue
clusters.append(tuple([idx]))
return tuple(clusters)
21 changes: 17 additions & 4 deletions rdkit/ML/Cluster/UnitTestButina.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
# $Id$
#
# Copyright (C) 2007-2008 Greg Landrum
# Copyright (C) 2007-2024 Greg Landrum and other RDKit contributors
# All Rights Reserved
#
import unittest
Expand Down Expand Up @@ -181,7 +179,22 @@ def test8_reordering_changes(self):
self.assertTrue(len(cs) == 2)
self.assertTrue(cs[0] == (4, 3, 5, 6))
self.assertTrue(cs[1] == (1, 0, 2))


def test9_empty_input(self):
# " edge case: empty input "
cs = Butina.ClusterData([], 0, 2, isDistData=1, reordering=True)
self.assertTrue(len(cs) == 0)

def test10_error_messages(self):
# " error case: mismatched data dimension and nPts "
with self.assertRaises(ValueError) as cm:
_ = Butina.ClusterData([1, 2, 3], 1, 2, isDistData=1, reordering=True)
self.assertEqual(str(cm.exception), "Mismatched input data dimension and nPts")

# " error case: invalid input data type "
with self.assertRaises(TypeError) as cm:
_ = Butina.ClusterData(0, 1, 2, isDistData=1, reordering=True)
self.assertEqual(str(cm.exception), f"Unsupported type for data, {type(0)}")

profileTest = 0

Expand Down

0 comments on commit 4bffa95

Please sign in to comment.