diff --git a/mat_discover/mat_discover_.py b/mat_discover/mat_discover_.py index 4f0651e..7310503 100644 --- a/mat_discover/mat_discover_.py +++ b/mat_discover/mat_discover_.py @@ -236,6 +236,7 @@ def __init__( plotting: bool = False, pdf: bool = True, n_peak_neighbors: int = 10, + radius=None, verbose: bool = True, dummy_run: bool = False, Scaler=RobustScaler, @@ -390,6 +391,7 @@ def __init__( self.plotting = plotting self.pdf = pdf self.n_peak_neighbors = n_peak_neighbors + self.radius = radius self.verbose = verbose self.dummy_run = dummy_run if dummy_run: @@ -768,7 +770,7 @@ def predict( # compound-wise scores (i.e. individual compounds) with self.Timer("nearest-neighbor-properties"): self.rad_neigh_avg_targ, self.k_neigh_avg_targ = nearest_neigh_props( - self.dm, pred, n_neighbors=self.n_peak_neighbors + self.dm, pred, n_neighbors=self.n_peak_neighbors, radius=self.radius ) self.val_rad_neigh_avg = self.rad_neigh_avg_targ[val_ids] self.val_k_neigh_avg = self.k_neigh_avg_targ[val_ids]