Skip to content

Commit

Permalink
Remove use of dask.array.ma in PCA in favour of array API compliant functions
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Oct 14, 2024
1 parent 161b0c7 commit 9dd940e
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions sgkit/stats/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def fit(
Alternate allele counts with missing values encoded as either nan
or negative numbers.
"""
X = da.ma.masked_array(X, mask=da.isnan(X) | (X < 0))
self.mean_ = da.ma.filled(da.mean(X, axis=0), fill_value=np.nan)
X = _replace_missing_with_nan(X)
self.mean_ = da.nanmean(X, axis=0)
p = self.mean_ / self.ploidy
self.scale_ = da.sqrt(p * (1 - p))
self.n_features_in_ = X.shape[1]
Expand All @@ -90,10 +90,10 @@ def transform(
Alternate allele counts with missing values encoded as either nan
or negative numbers.
"""
X = da.ma.masked_array(X, mask=da.isnan(X) | (X < 0))
X = _replace_missing_with_nan(X)
X -= self.mean_
X /= self.scale_
return da.ma.filled(X, fill_value=np.nan)
return X

def inverse_transform(self, X: ArrayLike, copy: Optional[bool] = None) -> ArrayLike:
"""Invert transform
Expand All @@ -109,6 +109,14 @@ def inverse_transform(self, X: ArrayLike, copy: Optional[bool] = None) -> ArrayL
return X


def _replace_missing_with_nan(X):
    """Return ``X`` with every negative entry replaced by NaN.

    Entries that are already NaN are left untouched, because the
    ``X < 0`` comparison is false for NaN and ``da.where`` then selects
    the original value.

    Parameters
    ----------
    X
        Array exposing a ``dtype`` attribute and supporting ``X < 0``.

    Returns
    -------
    The result of ``da.where(X < 0, nan, X)``.
    """
    negative = X < 0

    # Floating input: build the NaN fill value with the same dtype as X so
    # that the fill value itself does not introduce a different float width.
    if np.issubdtype(X.dtype, np.floating):
        return da.where(negative, da.asarray(np.nan, dtype=X.dtype), X)

    # Non-floating input cannot hold NaN, so the fill value is created with
    # the default dtype chosen by ``da.asarray`` for ``np.nan``.
    return da.where(negative, da.asarray(np.nan), X)


def filter_partial_calls(
ds: Dataset,
*,
Expand Down

0 comments on commit 9dd940e

Please sign in to comment.