Skip to content

Commit

Permalink
update closest kmeans func
Browse files Browse the repository at this point in the history
  • Loading branch information
csinva committed Nov 2, 2024
1 parent 7007dc0 commit e3d7960
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions imodels/clustering/stableclustering.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
import numpy as np
from sklearn.base import BaseEstimator, ClusterMixin
from sklearn.cluster import KMeans
Expand Down Expand Up @@ -72,6 +73,25 @@ def predict(self, X):
# elif self.algorithm == "nmf":
# return np.argmax(self.best_model.transform(X), axis=1)

def predict_closest_points_to_centroids(self, X, n_closest=1):
'''
Returns predicted cluster index for each point in X.
For the n_closest points of each cluster to each centroid, returns the cluster index.
Returns -1 for all other points.
'''
check_is_fitted(self, ["best_model_", "best_k_"])
distances = self.best_model_.transform(
X) # Shape: (n_samples, n_clusters)
cluster_assignments = self.predict(X)

preds = np.full(X.shape[0], -1)
for i in range(self.best_k_):
distances_ = deepcopy(distances[:, i])
distances_[cluster_assignments != i] = np.inf
closest_points = np.argsort(distances_)[:n_closest]
preds[closest_points] = i
return preds


if __name__ == '__main__':
# sample sklearn datraset
Expand Down

0 comments on commit e3d7960

Please sign in to comment.