Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
rtavenar committed May 17, 2021
2 parents a3cf3bf + d7b824c commit 4b3b244
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/requirements_rtd.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ tensorflow>=2
Pygments
numba
sphinx_bootstrap_theme
git+git://github.com/numpy/numpydoc@master
git+git://github.com/numpy/numpydoc@main
matplotlib
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
[build-system]
requires = ["setuptools", "wheel", "numpy", "Cython"]
requires = ["setuptools", "wheel", "numpy<=1.19", "Cython"]
3 changes: 3 additions & 0 deletions tslearn/metrics/dtw_variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ def dtw_path(s1, s2, global_constraint=None, sakoe_chiba_radius=None,
s1 = to_time_series(s1, remove_nans=True)
s2 = to_time_series(s2, remove_nans=True)

if len(s1) == 0 or len(s2) == 0:
raise ValueError("One of the input time series contains only nans or has zero length.")

mask = compute_mask(
s1, s2, GLOBAL_CONSTRAINT_CODE[global_constraint],
sakoe_chiba_radius, itakura_max_slope
Expand Down
15 changes: 15 additions & 0 deletions tslearn/tests/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import pytest
import numpy as np
from scipy.spatial.distance import cdist
import tslearn.metrics
import tslearn.clustering
from tslearn.utils import to_time_series
from tslearn.metrics.dtw_variants import dtw_path

__author__ = 'Romain Tavenard romain.tavenard[at]univ-rennes2.fr'

Expand Down Expand Up @@ -426,3 +428,16 @@ def test_softdtw():

np.testing.assert_equal(dist, dist_ref ** 2)
np.testing.assert_allclose(matrix_path, mat_path_ref)


def test_dtw_path_with_empty_or_nan_inputs():
s1 = np.zeros((3, 10))
s2_empty = np.zeros((0, 10))
with pytest.raises(ValueError) as excinfo:
dtw_path(s1, s2_empty)
assert str(excinfo.value) == "One of the input time series contains only nans or has zero length."

s2_nan = np.full((3, 10), np.nan)
with pytest.raises(ValueError) as excinfo:
dtw_path(s1, s2_nan)
assert str(excinfo.value) == "One of the input time series contains only nans or has zero length."

0 comments on commit 4b3b244

Please sign in to comment.