From 4bffa95db27ed8d9f08532ffa04fbe6a5abd0189 Mon Sep 17 00:00:00 2001 From: EvaSnow <10635420+evasnow1992@users.noreply.github.com> Date: Thu, 10 Oct 2024 21:44:03 -0700 Subject: [PATCH] Optimize Butina Clustering for Performance and Expand API (#7892) * Improve Butina clustering performance and readability with NumPy optimizations * Code cleanup, improved input data type flexibility, and corrected test expectation In Butina.py: - Cleaned up stale subversion comment lines - Removed type annotation - Allowed input data as tuple - Updated input data type check if condition to increase readability In UnitTestButina.py: - Cleaned up stale subversion comment lines - Updated test10 invalid input data type check from expecting ValueError to TypeError --- rdkit/ML/Cluster/Butina.py | 153 +++++++++++++++++------------ rdkit/ML/Cluster/UnitTestButina.py | 21 +++- 2 files changed, 108 insertions(+), 66 deletions(-) diff --git a/rdkit/ML/Cluster/Butina.py b/rdkit/ML/Cluster/Butina.py index b2fd2f463e1..232e20c72ca 100644 --- a/rdkit/ML/Cluster/Butina.py +++ b/rdkit/ML/Cluster/Butina.py @@ -1,30 +1,31 @@ -# $Id$ -# -# Copyright (C) 2007-2008 Greg Landrum +# Copyright (C) 2007-2024 Greg Landrum and other RDKit contributors # All Rights Reserved # """ Implementation of the clustering algorithm published in: Butina JCICS 39 747-750 (1999) """ -import numpy - -from rdkit import RDLogger - -logger = RDLogger.logger() - +import numpy as np def EuclideanDist(pi, pj): - dv = numpy.array(pi) - numpy.array(pj) - return numpy.sqrt(dv * dv) - + """Calculate the Euclidean distance between two points.""" + pi, pj = np.asarray(pi), np.asarray(pj) + return np.sqrt(np.sum((pi - pj) ** 2)) + +def compute_distance_matrix(data, n_pts, dist_func): + """Compute the distance matrix for the given data.""" + dist_matrix = np.zeros((n_pts, n_pts)) + for i in range(n_pts): + for j in range(i): + dist_matrix[i, j] = dist_matrix[j, i] = dist_func(data[i], data[j]) + return dist_matrix def ClusterData(data, nPts, distThresh, isDistData=False, distFunc=EuclideanDist, reordering=False): """ clusters the data points passed in and returns the list of clusters **Arguments** - - data: a list of items with the input data + - data: a list, tuple, or numpy array of items with the input data (see discussion of _isDistData_ argument for the exception) - nPts: the number of points to be used @@ -33,8 +34,10 @@ def ClusterData(data, nPts, distThresh, isDistData=False, distFunc=EuclideanDist to be neighbors - isDistData: set this toggle when the data passed in is a - distance matrix. The distance matrix should be stored - symmetrically. An example of how to do this: + distance matrix. The distance matrix should be stored + in one of two formats: as an nxn NumPy array, or as a + symmetrically stored list or 1D array generated using a + similar process to the example below: dists = [] for i in range(nPts): @@ -59,57 +62,83 @@ def ClusterData(data, nPts, distThresh, isDistData=False, distFunc=EuclideanDist The first element for each cluster is its centroid. """ - if isDistData and len(data) > (nPts * (nPts - 1) / 2): - logger.warning("Distance matrix is too long") - nbrLists = [None] * nPts - for i in range(nPts): - nbrLists[i] = [] - - dmIdx = 0 - for i in range(nPts): - for j in range(i): - if not isDistData: - dij = distFunc(data[i], data[j]) - else: - dij = data[dmIdx] - dmIdx += 1 - if dij <= distThresh: - nbrLists[i].append(j) - nbrLists[j].append(i) - # sort by the number of neighbors: - tLists = [(len(y), x) for x, y in enumerate(nbrLists)] - tLists.sort(reverse=True) - - res = [] - seen = [0] * nPts - while tLists: - _, idx = tLists.pop(0) + if isDistData: + # Check if data is a supported type + if not isinstance(data, (list, tuple, np.ndarray)): + raise TypeError(f"Unsupported type for data, {type(data)}") + + # Check if data is a 1D array or list + if isinstance(data, (list, tuple)) or (isinstance(data, np.ndarray) and data.ndim == 1): + # Check if data length matches the required number of points + if len(data) != (nPts * (nPts - 1)) // 2: + raise ValueError("Mismatched input data dimension and nPts") + + # Create a distance matrix from the 1D data + dist_matrix = np.zeros((nPts, nPts)) + idx = np.tril_indices(nPts, -1) + dist_matrix[idx] = data + dist_matrix += dist_matrix.T + else: + # Check if data is a matrix of the correct shape and use it as distance matrix + if data.shape != (nPts, nPts): + raise ValueError(f"Input data with shape {data.shape} is not a matrix of the required shape {(nPts, nPts)}") + dist_matrix = data + else: + # Compute distance matrix from the data points + dist_matrix = compute_distance_matrix(data, nPts, distFunc) + + # Initialize neighbor lists + neighbor_lists = [np.where(dist_matrix[i] <= distThresh)[0].tolist() for i in range(nPts)] + + # Sort points by the number of neighbors in descending order + sorted_indices = [(len(neighbors), idx) for idx, neighbors in enumerate(neighbor_lists)] + sorted_indices.sort(reverse=True) + + # Initialize clusters and a seen array to keep track of processed points + clusters = [] + seen = np.zeros(nPts, dtype=bool) + + # Process all candidate clusters that have at least two members + while sorted_indices and sorted_indices[0][0] > 1: + _, idx = sorted_indices.pop(0) if seen[idx]: continue - tRes = [idx] - for nbr in nbrLists[idx]: - if not seen[nbr]: - tRes.append(nbr) - seen[nbr] = 1 - # update the number of neighbors: + + # Create a new cluster and mark points as seen + cluster = [idx] + seen[idx] = True + for neighbor in neighbor_lists[idx]: + if not seen[neighbor]: + cluster.append(neighbor) + seen[neighbor] = True + + clusters.append(tuple(cluster)) + + # Update the number of neighbors: # remove all members of the new cluster from the list of - # neighbors and reorder the tLists + # neighbors and reorder the sorted_indices if reordering: - # get the list of affected molecules, i.e. all molecules + # Get the set of unassigned and affected molecules, i.e. all unseen molecules # which have at least one of the members of the new cluster # as a neighbor - nbrNbr = [nbrLists[t] for t in tRes] - nbrNbr = frozenset().union(*nbrNbr) - # loop over all remaining molecules in tLists but only + affected = set(neighbor for point in cluster for neighbor in neighbor_lists[point] if not seen[neighbor]) + + # Loop over all remaining molecules in sorted_indices but only # consider unassigned and affected compounds - for x, y in enumerate(tLists): - y1 = y[1] - if seen[y1] or (y1 not in nbrNbr): - continue - # update the number of neighbors - nbrLists[y1] = set(nbrLists[y1]).difference(tRes) - tLists[x] = (len(nbrLists[y1]), y1) - # now reorder the list - tLists.sort(reverse=True) - res.append(tuple(tRes)) - return tuple(res) + for ii, element in enumerate(sorted_indices): + affected_point = element[1] + if affected_point in affected: + # Update the number of neighbors + new_neighbors = [nbr for nbr in neighbor_lists[affected_point] if not seen[nbr]] + neighbor_lists[affected_point] = new_neighbors + sorted_indices[ii] = (len(new_neighbors), affected_point) + # Reorder the list + sorted_indices.sort(reverse=True) + + # Process any remaining single-point clusters + while sorted_indices: + _, idx = sorted_indices.pop(0) + if seen[idx]: + continue + clusters.append(tuple([idx])) + return tuple(clusters) diff --git a/rdkit/ML/Cluster/UnitTestButina.py b/rdkit/ML/Cluster/UnitTestButina.py index 5c144a1c6e3..b9e2f1ae3a6 100755 --- a/rdkit/ML/Cluster/UnitTestButina.py +++ b/rdkit/ML/Cluster/UnitTestButina.py @@ -1,6 +1,4 @@ -# $Id$ -# -# Copyright (C) 2007-2008 Greg Landrum +# Copyright (C) 2007-2024 Greg Landrum and other RDKit contributors # All Rights Reserved # import unittest @@ -181,7 +179,22 @@ def test8_reordering_changes(self): self.assertTrue(len(cs) == 2) self.assertTrue(cs[0] == (4, 3, 5, 6)) self.assertTrue(cs[1] == (1, 0, 2)) - + + def test9_empty_input(self): + # " edge case: empty input " + cs = Butina.ClusterData([], 0, 2, isDistData=1, reordering=True) + self.assertTrue(len(cs) == 0) + + def test10_error_messages(self): + # " error case: mismatched data dimension and nPts " + with self.assertRaises(ValueError) as cm: + _ = Butina.ClusterData([1, 2, 3], 1, 2, isDistData=1, reordering=True) + self.assertEqual(str(cm.exception), "Mismatched input data dimension and nPts") + + # " error case: invalid input data type " + with self.assertRaises(TypeError) as cm: + _ = Butina.ClusterData(0, 1, 2, isDistData=1, reordering=True) + self.assertEqual(str(cm.exception), f"Unsupported type for data, {type(0)}") profileTest = 0