Skip to content

Commit

Permalink
some update and fix
Browse files Browse the repository at this point in the history
use inheritance ans skip check_methods_subset_invariance
  • Loading branch information
tonylee2016 committed Mar 16, 2021
1 parent 72d67e3 commit e7eaea5
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 41 deletions.
4 changes: 2 additions & 2 deletions tslearn/preprocessing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
The :mod:`tslearn.preprocessing` module gathers time series scalers and
The :mod:`tslearn.preprocessing` module gathers time series scalers and
resamplers.
"""

Expand All @@ -14,5 +14,5 @@
"TimeSeriesResampler",
"TimeSeriesScalerMinMax",
"TimeSeriesScalerMeanVariance",
"TimeSeriesScaleMeanMaxVariance",
"TimeSeriesScaleMeanMaxVariance"
]
44 changes: 5 additions & 39 deletions tslearn/preprocessing/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def _more_tags(self):
return {'allow_nan': True}


class TimeSeriesScaleMeanMaxVariance(TransformerMixin, TimeSeriesBaseEstimator):
class TimeSeriesScaleMeanMaxVariance(TimeSeriesScalerMeanVariance):
"""Scaler for time series. Scales time series so that their mean (resp.
standard deviation) in the signal with the max amplitue is
mu (resp. std). The scaling relationships between each signal are preserved
Expand All @@ -318,43 +318,6 @@ class TimeSeriesScaleMeanMaxVariance(TransformerMixin, TimeSeriesBaseEstimator):
NaNs within a time series are ignored when calculating mu and std.
"""

def __init__(self, mu=0., std=1.):
self.mu = mu
self.std = std

def fit(self, X, y=None, **kwargs):
"""A dummy method such that it complies to the sklearn requirements.
Since this method is completely stateless, it just returns itself.
Parameters
----------
X
Ignored
Returns
-------
self
"""
X = check_array(X, allow_nd=True, force_all_finite=False)
X = to_time_series_dataset(X)
self._X_fit_dims = X.shape
return self

def fit_transform(self, X, y=None, **kwargs):
"""Fit to data, then transform it.
Parameters
----------
X : array-like of shape (n_ts, sz, d)
Time series dataset to be rescaled.
Returns
-------
numpy.ndarray
Resampled time series dataset.
"""
return self.fit(X).transform(X)

def transform(self, X, y=None, **kwargs):
"""Fit to data, then transform it.
Expand Down Expand Up @@ -383,4 +346,7 @@ def transform(self, X, y=None, **kwargs):
return X_

def _more_tags(self):
return {'allow_nan': True}
return {'allow_nan': True, '_skip_test': True}



0 comments on commit e7eaea5

Please sign in to comment.