Skip to content

Commit

Permalink
Table: Fix ensure_copy for sparse matrices
Browse files Browse the repository at this point in the history
`ensure_copy` is checking whether an array is a view through the `.base` argument which doesn't exist on `scipy.sparse` matrices. Since `scipy.sparse` don't work with views (e.g. indexing returns a copy of the matrix) the check should only be performed for dense matrices.
  • Loading branch information
nikicc committed Jul 14, 2016
1 parent 4a484f6 commit f30886e
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
14 changes: 9 additions & 5 deletions Orange/data/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,15 +912,19 @@ def is_copy(self):

def ensure_copy(self):
"""
Ensure that the table owns its data; copy arrays when necessary
Ensure that the table owns its data; copy arrays when necessary.
The check is skipped for sparse matrices since they don't have views like numpy arrays.
"""
if self.X.base is not None:
def is_view(x):
return not sp.issparse(x) and x.base is not None

if is_view(self.X):
self.X = self.X.copy()
if self._Y.base is not None:
if is_view(self._Y):
self._Y = self._Y.copy()
if self.metas.base is not None:
if is_view(self.metas):
self.metas = self.metas.copy()
if self.W.base is not None:
if is_view(self.W):
self.W = self.W.copy()

def copy(self):
Expand Down
13 changes: 13 additions & 0 deletions Orange/tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,19 @@ def test_copy(self):
self.assertFalse(np.all(t.Y == copy.Y))
self.assertFalse(np.all(t.metas == copy.metas))

def test_copy_sparse(self):
t = data.Table('iris')
t.X = csr_matrix(t.X)
copy = t.copy()

self.assertEqual((t.X != copy.X).nnz, 0) # sparse matrices match by content
np.testing.assert_equal(t.Y, copy.Y)
np.testing.assert_equal(t.metas, copy.metas)

self.assertNotEqual(id(t.X), id(copy.X))
self.assertNotEqual(id(t._Y), id(copy._Y))
self.assertNotEqual(id(t.metas), id(copy.metas))

def test_concatenate(self):
d1 = data.Domain([data.ContinuousVariable('a1')])
t1 = data.Table.from_numpy(d1, [[1],
Expand Down

0 comments on commit f30886e

Please sign in to comment.