Skip to content

Commit

Permalink
Move input validation to distance method
Browse files Browse the repository at this point in the history
  • Loading branch information
AndrejaKovacic committed Oct 25, 2019
1 parent 93b0494 commit 3e6b549
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 12 deletions.
22 changes: 19 additions & 3 deletions Orange/distance/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from scipy import stats
from scipy import sparse as sp
import sklearn.metrics as skl_metrics
from sklearn.utils import check_array
from sklearn.utils.extmath import row_norms, safe_sparse_dot
from sklearn.metrics import pairwise_distances

Expand Down Expand Up @@ -645,16 +646,31 @@ def fit(self, _):
return PearsonModel(True, self.axis, self.impute)

def _prob_dist(a):
    """Rescale `a` so that its entries sum to one.

    The result mimics a probability distribution: every entry is divided
    by the total of all entries in `a`.
    """
    total = np.sum(a)
    return a / total

def non_negative(a):
    """Check that `a` holds only finite, non-negative numbers.

    Parameters
    ----------
    a : array-like or scipy sparse matrix
        Values to validate. For sparse input only the explicitly stored
        entries are inspected, since implicit zeros are always valid.

    Raises
    ------
    ValueError
        If `a` cannot be interpreted as numeric data, is an empty dense
        array, or contains NaN, infinity or a negative value.
    """
    message = "Bhattacharyya distance requires non-negative values"
    is_sparse = sp.issparse(a)
    # COO exposes the stored entries as a flat array for every sparse format.
    values = a.tocoo().data if is_sparse else a
    try:
        values = np.asarray(values, dtype=float)
    except (TypeError, ValueError) as err:
        # Non-numeric input; keep the original cause attached for debugging.
        raise ValueError(message) from err
    if not is_sparse and values.size == 0:
        raise ValueError(message)
    # Test finiteness first so NaN never reaches the `< 0` comparison.
    if not np.isfinite(values).all() or (values < 0).any():
        raise ValueError(message)

def _bhattacharyya(a, b):
    """Return the Bhattacharyya distance between `a` and `b`.

    Both inputs are validated with `non_negative` and rescaled with
    `_prob_dist` before the distance is computed as
    ``-log(sum(sqrt(a * b)))``.

    This is not a real metric: it does not obey the triangle inequality.

    Raises
    ------
    ValueError
        Propagated from `non_negative` when either input is rejected.
    """
    non_negative(a)
    non_negative(b)
    a = _prob_dist(a)
    b = _prob_dist(b)
    # Sparse inputs take the element-wise product through `.multiply`.
    product = a.multiply(b) if sp.issparse(a) else a * b
    return -np.log(np.sum(np.sqrt(product)))

class Bhattacharyya(Distance):
supports_discrete = False
Expand Down
12 changes: 11 additions & 1 deletion Orange/tests/test_distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,13 +962,23 @@ def test_dense_array(self):
def test_sparse_array(self):
    """Distance between two rows of a sparse matrix matches the known value."""
    rows = csr_matrix([[0.5, 0.5], [0, 0.5]])
    first, second = rows[0], rows[1]
    self.assertAlmostEqual(self.dist(first, second), 0.3465735902799726, delta=1e-5)

def test_columns(self):
    """Column-wise distances (axis=0) form the expected symmetric matrix."""
    data = np.array([[0.5, 0.2], [0.5, 0.8]])
    off_diagonal = 0.05268025782891318
    expected = np.array([[0, off_diagonal],
                         [off_diagonal, 0]])
    result = self.dist(data, axis=0)
    np.testing.assert_array_almost_equal(result, expected)

def test_negative_input(self):
    """NaN and negative entries are rejected for dense and sparse input."""
    ones = np.array([1, 1])
    with_nan = np.array([0, np.nan])
    with_negative = np.array([0.0, -1.0])
    # Dense inputs: first the NaN case, then the negative case.
    for bad in (with_nan, with_negative):
        with self.assertRaises(ValueError):
            self.dist(bad, ones)
    # The same negative vector, passed as sparse matrices.
    with self.assertRaises(ValueError):
        self.dist(csr_matrix(with_negative), csr_matrix(ones))


class TestDistances(TestCase):
@classmethod
Expand Down
5 changes: 0 additions & 5 deletions Orange/widgets/unsupervised/owdistances.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from AnyQt.QtCore import Qt
from scipy.sparse import issparse
from numpy import min as _min
import bottleneck as bn

import Orange.data
Expand Down Expand Up @@ -59,7 +58,6 @@ class Error(OWWidget.Error):
dense_metric_sparse_data = Msg("{} requires dense data.")
distances_memory_error = Msg("Not enough memory")
distances_value_error = Msg("Problem in calculation:\n{}")
negative_value_error = Msg("Only non-negative values alowed for Bhattcharyya.")

class Warning(OWWidget.Warning):
ignoring_discrete = Msg("Ignoring categorical features")
Expand Down Expand Up @@ -160,9 +158,6 @@ def _fix_missing():
_fix_discrete, _fix_missing, _fix_nonbinary):
if not check():
return None
if (METRICS[self.metric_idx][0] == 'Bhattacharyya') and _min(data.X) < 0:
self.Error.negative_value_error()
return None
try:
if metric.supports_normalization and self.normalized_dist:
return metric(data, axis=1 - self.axis, impute=True,
Expand Down
6 changes: 3 additions & 3 deletions Orange/widgets/unsupervised/tests/test_owdistances.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ def test_migrates_normalized_dist(self):

def test_negative_values_bhattacharyya(self):
    """Negative data with the Bhattacharyya metric shows the value error.

    One entry of the iris data is negated in place, the widget is switched
    to the Bhattacharyya metric, and the data is sent to the widget. The
    test passes when `distances_value_error` is shown.
    """
    self.iris.X[0, 0] *= -1
    try:
        # The loop target selects the metric on the widget as it iterates.
        for self.widget.metric_idx, (_, metric) in enumerate(METRICS):
            if metric == distance.Bhattacharyya:
                break
        self.send_signal(self.widget.Inputs.data, self.iris)
        self.assertTrue(self.widget.Error.distances_value_error.is_shown())
    finally:
        # Undo the in-place change even if an assertion above fails.
        self.iris.X[0, 0] *= -1

0 comments on commit 3e6b549

Please sign in to comment.