Skip to content

Commit

Permalink
warn if not csr
Browse files Browse the repository at this point in the history
  • Loading branch information
adamgayoso committed Oct 14, 2020
1 parent d1bf2ff commit 3303ca7
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 26 deletions.
6 changes: 3 additions & 3 deletions scvi/data/_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,9 @@ def _verify_and_correct_data_format(adata, data_registry):
for k in keys:
data = get_from_registry(adata, k)
if isspmatrix(data) and (data.getformat() != "csr"):
logger.debug("{} is csc_matrix. Overwriting to csr_matrix.".format(k))
data = data.tocsr()
_set_data_in_registry(adata, data, k)
logger.warning(
"Training will be faster when sparse matrix is formatted as CSR. It is safe to cast before model initialization."
)
elif isinstance(data, np.ndarray) and (data.flags["C_CONTIGUOUS"] is False):
logger.debug(
"{} is not C_CONTIGUOUS. Overwriting to C_CONTIGUOUS.".format(k)
Expand Down
23 changes: 0 additions & 23 deletions tests/dataset/test_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
)
from scvi import _CONSTANTS
from scvi.data._anndata import get_from_registry
from scipy.sparse import csc_matrix


def test_transfer_anndata_setup():
Expand Down Expand Up @@ -124,28 +123,6 @@ def test_data_format():
get_from_registry(adata, _CONSTANTS.PROTEIN_EXP_KEY),
)

# if data is sparse, check that after setup_anndata, data is csr_matrix
adata = synthetic_iid(run_setup_anndata=False)
adata.X = csc_matrix(adata.X)
adata.obsm["protein_expression"] = csc_matrix(adata.obsm["protein_expression"])
old_x = adata.X
old_pro = adata.obsm["protein_expression"]
setup_anndata(adata, protein_expression_obsm_key="protein_expression")

assert adata.X.getformat() == "csr"
assert adata.obsm["protein_expression"].getformat() == "csr"

assert np.array_equal(old_x.toarray(), adata.X.toarray())
assert np.array_equal(old_pro.toarray(), adata.obsm["protein_expression"].toarray())

assert np.array_equal(
adata.X.toarray(), get_from_registry(adata, _CONSTANTS.X_KEY).toarray()
)
assert np.array_equal(
adata.obsm["protein_expression"].toarray(),
get_from_registry(adata, _CONSTANTS.PROTEIN_EXP_KEY).toarray(),
)

# if obsm is dataframe, make it C_CONTIGUOUS if it isnt
adata = synthetic_iid()
pe = np.asfortranarray(adata.obsm["protein_expression"])
Expand Down

0 comments on commit 3303ca7

Please sign in to comment.