diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 049042c..a4a7ed3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -29,7 +29,7 @@ jobs: cd ${GITHUB_WORKSPACE}/anndata-hdf5 && cargo test --no-fail-fast cd ${GITHUB_WORKSPACE}/anndata && cargo test --no-fail-fast cd ${GITHUB_WORKSPACE}/anndata-test-utils && cargo test --no-fail-fast - cd ${GITHUB_WORKSPACE}/python && pip install --user .[test] + cd ${GITHUB_WORKSPACE}/python && pip install --user '.[test]' pytest -v --durations=0 ${GITHUB_WORKSPACE}/python/tests - name: benchmark diff --git a/anndata/src/traits.rs b/anndata/src/traits.rs index b9672fa..ab12c29 100644 --- a/anndata/src/traits.rs +++ b/anndata/src/traits.rs @@ -92,40 +92,66 @@ pub trait AnnDataOp { fn layers(&self) -> Self::AxisArraysRef<'_>; /// Sets the unstructured data. - fn set_uns>(&self, mut data: I) -> Result<()> { + fn set_uns(&self, data: I) -> Result<()> + where + I: IntoIterator, + D: Into, + { self.del_uns()?; let uns = self.uns(); - data.try_for_each(|(k, v)| uns.add(&k, v)) + data.into_iter().try_for_each(|(k, v)| uns.add(&k, v)) } + /// Sets the observation matrix. - fn set_obsm>(&self, mut data: I) -> Result<()> { + fn set_obsm(&self, data: I) -> Result<()> + where + I: IntoIterator, + D: Into, + { self.del_obsm()?; let obsm = self.obsm(); - data.try_for_each(|(k, v)| obsm.add(&k, v)) + data.into_iter().try_for_each(|(k, v)| obsm.add(&k, v)) } + /// Sets the observation pairwise data. - fn set_obsp>(&self, mut data: I) -> Result<()> { + fn set_obsp(&self, data: I) -> Result<()> + where + I: IntoIterator, + D: Into, + { self.del_obsp()?; let obsp = self.obsp(); - data.try_for_each(|(k, v)| obsp.add(&k, v)) + data.into_iter().try_for_each(|(k, v)| obsp.add(&k, v)) } /// Sets the variable matrix. - fn set_varm>(&self, mut data: I) -> Result<()> { + fn set_varm(&self, data: I) -> Result<()> + where + I: IntoIterator, + D: Into, + { self.del_varm()?; let varm = self.varm(); - data.try_for_each(|(k, v)| varm.add(&k, v)) + data.into_iter().try_for_each(|(k, v)| varm.add(&k, v)) } /// Sets the variable pairwise data. - fn set_varp>(&self, mut data: I) -> Result<()> { + fn set_varp(&self, data: I) -> Result<()> + where + I: IntoIterator, + D: Into, + { self.del_varp()?; let varp = self.varp(); - data.try_for_each(|(k, v)| varp.add(&k, v)) + data.into_iter().try_for_each(|(k, v)| varp.add(&k, v)) } /// Sets the layers. - fn set_layers>(&self, mut data: I) -> Result<()> { + fn set_layers(&self, data: I) -> Result<()> + where + I: IntoIterator, + D: Into, + { self.del_layers()?; let layers = self.layers(); - data.try_for_each(|(k, v)| layers.add(&k, v)) + data.into_iter().try_for_each(|(k, v)| layers.add(&k, v)) } /// Deletes the unstructured data. @@ -520,19 +546,43 @@ impl AnnDataOp for AnnDataSet { self.annotation.layers() } - fn set_uns>(&self, data: I) -> Result<()> { + fn set_uns(&self, data: I) -> Result<()> + where + I: IntoIterator, + D: Into, + { self.annotation.set_uns(data) } - fn set_obsm>(&self, data: I) -> Result<()> { + + fn set_obsm(&self, data: I) -> Result<()> + where + I: IntoIterator, + D: Into, + { self.annotation.set_obsm(data) } - fn set_obsp>(&self, data: I) -> Result<()> { + + fn set_obsp(&self, data: I) -> Result<()> + where + I: IntoIterator, + D: Into, + { self.annotation.set_obsp(data) } - fn set_varm>(&self, data: I) -> Result<()> { + + fn set_varm(&self, data: I) -> Result<()> + where + I: IntoIterator, + D: Into, + { self.annotation.set_varm(data) } - fn set_varp>(&self, data: I) -> Result<()> { + + fn set_varp(&self, data: I) -> Result<()> + where + I: IntoIterator, + D: Into, + { self.annotation.set_varp(data) } diff --git a/pyanndata/src/anndata.rs b/pyanndata/src/anndata.rs index 30a182b..2fd37b6 100644 --- a/pyanndata/src/anndata.rs +++ b/pyanndata/src/anndata.rs @@ -11,9 +11,25 @@ use anndata; use anndata::Backend; use anndata_hdf5::H5; use pyo3::prelude::*; -use std::{path::PathBuf, collections::HashMap}; +use std::{collections::HashMap, path::{Path, PathBuf}}; use anyhow::Result; +pub(crate) fn get_backend>(filename: P, backend: Option<&str>) -> &str { + if let Some(backend) = backend { + backend + } else { + if let Some(ext) = filename.as_ref().extension() { + match ext.to_str().unwrap() { + "zarr" | "zarrs" => Zarr::NAME, + "h5ad" | "h5" | "h5ads" => H5::NAME, + _ => H5::NAME, + } + } else { + H5::NAME + } + } +} + /// Read `.h5ad`-formatted hdf5 file. /// /// Parameters @@ -29,12 +45,15 @@ use anyhow::Result; /// backend: Literal['hdf5', 'zarr'] #[pyfunction] #[pyo3( - signature = (filename, backed="r+", backend=H5::NAME), - text_signature = "(filename, backed='r+', backend='hdf5')", + signature = (filename, backed="r+", backend=None), + text_signature = "(filename, backed='r+', backend=None)", )] -pub fn read<'py>(py: Python<'py>, filename: PathBuf, backed: Option<&str>, backend: &str) -> Result { +pub fn read<'py>(py: Python<'py>, filename: PathBuf, backed: Option<&str>, backend: Option<&str>) -> Result { let adata = match backed { - Some(m) => AnnData::new_from(filename, m, backend).unwrap().into_py(py), + Some(m) => { + let backend = get_backend(&filename, backend); + AnnData::new_from(filename, m, backend).unwrap().into_py(py) + }, None => PyModule::import_bound(py, "anndata")? .getattr("read_h5ad")? .call1((filename,))? @@ -93,8 +112,8 @@ pub fn concat<'py>( /// Sorted input matrix can be read faster. #[pyfunction] #[pyo3( - signature = (mtx_file, *, obs_names=None, var_names=None, file=None, backend=H5::NAME, sorted=false), - text_signature = "(mtx_file, *, obs_names=None, var_names=None, file=None, backend='hdf5', sorted=False)", + signature = (mtx_file, *, obs_names=None, var_names=None, file=None, backend=None, sorted=false), + text_signature = "(mtx_file, *, obs_names=None, var_names=None, file=None, backend=None, sorted=False)", )] pub fn read_mtx( py: Python<'_>, @@ -102,7 +121,7 @@ pub fn read_mtx( obs_names: Option, var_names: Option, file: Option, - backend: &str, + backend: Option<&str>, sorted: bool, ) -> Result { let mut reader = anndata::reader::MMReader::from_path(mtx_file)?; @@ -116,6 +135,7 @@ pub fn read_mtx( reader = reader.is_sorted(); } if let Some(file) = file { + let backend = get_backend(&file, backend); match backend { H5::NAME => { let adata = anndata::AnnData::
::new(file)?; @@ -162,20 +182,21 @@ pub fn read_mtx( /// AnnDataSet #[pyfunction] #[pyo3( - signature = (filename, *, adata_files_update=None, mode="r+", backend=H5::NAME), - text_signature = "(filename, *, adata_files_update=None, mode='r+', backend='hdf5')", + signature = (filename, *, adata_files_update=None, mode="r+", backend=None), + text_signature = "(filename, *, adata_files_update=None, mode='r+', backend=None)", )] pub fn read_dataset( filename: PathBuf, adata_files_update: Option, mode: &str, - backend: &str, + backend: Option<&str>, ) -> Result { let adata_files_update = match adata_files_update { Some(LocationUpdate::Map(map)) => Some(Ok(map)), Some(LocationUpdate::Dir(dir)) => Some(Err(dir)), None => None, }; + let backend = get_backend(&filename, backend); match backend { H5::NAME => { let file = match mode { diff --git a/pyanndata/src/anndata/backed.rs b/pyanndata/src/anndata/backed.rs index afaed41..30b4f8f 100644 --- a/pyanndata/src/anndata/backed.rs +++ b/pyanndata/src/anndata/backed.rs @@ -18,6 +18,8 @@ use std::collections::HashMap; use std::ops::Deref; use std::path::PathBuf; +use super::get_backend; + /** An annotated data matrix. `AnnData` stores a data matrix `X` together with annotations of @@ -182,10 +184,10 @@ impl AnnData { #[pyo3( signature = ( *, filename, X=None, obs=None, var=None, obsm=None, - varm=None, uns=None, backend=H5::NAME, + varm=None, uns=None, backend=None, ), text_signature = "($self, *, filename, X=None, obs=None, var=None, obsm=None, - varm=None, uns=None, backend='hdf5')" + varm=None, uns=None, backend=None)" )] pub fn new( filename: PathBuf, @@ -195,8 +197,9 @@ impl AnnData { obsm: Option>, varm: Option>, uns: Option>, - backend: &str, + backend: Option<&str>, ) -> Result { + let backend = get_backend(&filename, backend); let adata: AnnData = match backend { H5::NAME => anndata::AnnData::
::new(filename)?.into(), Zarr::NAME => anndata::AnnData::::new(filename)?.into(), @@ -407,15 +410,15 @@ impl AnnData { /// is returned. This parameter is ignored when `inplace=True`. /// inplace: bool /// Whether to modify the AnnData object in place or return a new AnnData object. - /// backend: str + /// backend: str | None /// The backend to use. "hdf5" or "zarr" are supported. /// /// Returns /// ------- /// Optional[AnnData] #[pyo3( - signature = (obs_indices=None, var_indices=None, *, out=None, inplace=true, backend=H5::NAME), - text_signature = "($self, obs_indices=None, var_indices=None, *, out=None, inplace=True, backend='hdf5')", + signature = (obs_indices=None, var_indices=None, *, out=None, inplace=true, backend=None), + text_signature = "($self, obs_indices=None, var_indices=None, *, out=None, inplace=True, backend=None)", )] pub fn subset( &self, @@ -424,7 +427,7 @@ impl AnnData { var_indices: Option<&Bound<'_, PyAny>>, out: Option, inplace: bool, - backend: &str, + backend: Option<&str>, ) -> Result> { let i = obs_indices .map(|x| self.select_obs(x).unwrap()) @@ -594,7 +597,7 @@ trait AnnDataTrait: Send + Downcast { slice: &[SelectInfoElem], file: Option, inplace: bool, - backend: &str, + backend: Option<&str>, ) -> Result>; fn chunked_x(&self, chunk_size: usize) -> PyChunkedArray; @@ -801,7 +804,7 @@ impl AnnDataTrait for InnerAnnData { fn set_uns(&self, uns: Option>) -> Result<()> { let inner = self.adata.inner(); if let Some(u) = uns { - inner.set_uns(u.into_iter().map(|(k, v)| (k, v.into())))?; + inner.set_uns(u)?; } else { inner.del_uns()?; } @@ -810,7 +813,7 @@ impl AnnDataTrait for InnerAnnData { fn set_obsm(&self, obsm: Option>) -> Result<()> { let inner = self.adata.inner(); if let Some(o) = obsm { - inner.set_obsm(o.into_iter().map(|(k, v)| (k, v.into())))?; + inner.set_obsm(o)?; } else { inner.del_obsm()?; } @@ -819,7 +822,7 @@ impl AnnDataTrait for InnerAnnData { fn set_obsp(&self, obsp: Option>) -> Result<()> { let inner = self.adata.inner(); if let Some(o) = obsp { - inner.set_obsp(o.into_iter().map(|(k, v)| (k, v.into())))?; + inner.set_obsp(o)?; } else { inner.del_obsp()?; } @@ -828,7 +831,7 @@ impl AnnDataTrait for InnerAnnData { fn set_varm(&self, varm: Option>) -> Result<()> { let inner = self.adata.inner(); if let Some(v) = varm { - inner.set_varm(v.into_iter().map(|(k, v)| (k, v.into())))?; + inner.set_varm(v)?; } else { inner.del_varm()?; } @@ -837,7 +840,7 @@ impl AnnDataTrait for InnerAnnData { fn set_varp(&self, varp: Option>) -> Result<()> { let inner = self.adata.inner(); if let Some(v) = varp { - inner.set_varp(v.into_iter().map(|(k, v)| (k, v.into())))?; + inner.set_varp(v)?; } else { inner.del_varp()?; } @@ -846,7 +849,7 @@ impl AnnDataTrait for InnerAnnData { fn set_layers(&self, varp: Option>) -> Result<()> { let inner = self.adata.inner(); if let Some(v) = varp { - inner.set_layers(v.into_iter().map(|(k, v)| (k, v.into())))?; + inner.set_layers(v)?; } else { inner.del_layers()?; } @@ -859,13 +862,14 @@ impl AnnDataTrait for InnerAnnData { slice: &[SelectInfoElem], file: Option, inplace: bool, - backend: &str, + backend: Option<&str>, ) -> Result> { let inner = self.adata.inner(); if inplace { inner.subset(slice)?; Ok(None) } else if let Some(file) = file { + let backend = get_backend(&file, backend); match backend { H5::NAME => { inner.write_select::(slice, &file)?; diff --git a/pyanndata/src/anndata/dataset.rs b/pyanndata/src/anndata/dataset.rs index 6ab2d26..5c8651d 100644 --- a/pyanndata/src/anndata/dataset.rs +++ b/pyanndata/src/anndata/dataset.rs @@ -19,6 +19,7 @@ use std::collections::HashMap; use std::path::PathBuf; use super::backed::StackedAnnData; +use super::get_backend; /** Similar to `AnnData`, `AnnDataSet` contains annotations of observations `obs` (`obsm`, `obsp`), variables `var` (`varm`, `varp`), @@ -118,13 +119,14 @@ pub enum AnnDataFile<'py> { #[pymethods] impl AnnDataSet { #[new] - #[pyo3(signature = (adatas, *, filename, add_key="sample", backend=H5::NAME))] + #[pyo3(signature = (adatas, *, filename, add_key="sample", backend=None))] pub fn new( adatas: Vec<(String, AnnDataFile)>, filename: PathBuf, add_key: &str, - backend: &str, + backend: Option<&str>, ) -> Result { + let backend = get_backend(&filename, backend); match backend { H5::NAME => { let anndatas = adatas.into_iter().map(|(key, data_file)| { @@ -335,19 +337,21 @@ impl AnnDataSet { /// If the order of input `obs_indices` has been changed, it will /// return the indices that would sort the `obs_indices` array. #[pyo3( - signature = (obs_indices=None, var_indices=None, out=None, backend=H5::NAME), - text_signature = "($self, obs_indices=None, var_indices=None, out=None, backend='hdf5')" + signature = (obs_indices=None, var_indices=None, out=None, backend=None), + text_signature = "($self, obs_indices=None, var_indices=None, out=None, backend=None)" )] pub fn subset( &self, obs_indices: Option<&Bound<'_, PyAny>>, var_indices: Option<&Bound<'_, PyAny>>, out: Option, - backend: &str, + backend: Option<&str>, ) -> Result<(AnnDataSet, Option>)> { if out.is_none() { bail!("AnnDataSet cannot be subsetted in place. Please provide an output directory."); } + let out = out.unwrap(); + let backend = get_backend(&out, backend); let i = obs_indices .map(|x| self.select_obs(x).unwrap()) .unwrap_or(SelectInfoElem::full()); @@ -355,7 +359,7 @@ impl AnnDataSet { .map(|x| self.select_var(x).unwrap()) .unwrap_or(SelectInfoElem::full()); self.0 - .subset(&[i, j], out.unwrap(), backend) + .subset(&[i, j], out, backend) } /// View into the component AnnData objects. @@ -370,8 +374,8 @@ impl AnnDataSet { /// Convert AnnDataSet to AnnData object. #[pyo3( - signature = (obs_indices=None, var_indices=None, copy_x=true, file=None, backend=H5::NAME), - text_signature = "($self, obs_indices=None, var_indices=None, copy_x=True, file=None, backed='hdf5')", + signature = (obs_indices=None, var_indices=None, copy_x=true, file=None, backend=None), + text_signature = "($self, obs_indices=None, var_indices=None, copy_x=True, file=None, backed=None)", )] pub fn to_adata( &self, @@ -380,7 +384,7 @@ impl AnnDataSet { var_indices: Option<&Bound<'_, PyAny>>, copy_x: bool, file: Option, - backend: &str, + backend: Option<&str>, ) -> Result { let i = obs_indices .map(|x| self.select_obs(x).unwrap()) @@ -493,7 +497,7 @@ trait AnnDataSetTrait: Send + Downcast { slice: &[SelectInfoElem], copy_x: bool, file: Option, - backend: &str, + backend: Option<&str>, ) -> Result; fn chunked_x(&self, chunk_size: usize) -> PyChunkedArray; @@ -647,7 +651,7 @@ impl AnnDataSetTrait for Slot> { fn set_uns(&self, uns: Option>) -> Result<()> { let inner = self.inner(); if let Some(u) = uns { - inner.set_uns(u.into_iter().map(|(k, v)| (k, v.into())))?; + inner.set_uns(u)?; } else { inner.del_uns()?; } @@ -656,7 +660,7 @@ impl AnnDataSetTrait for Slot> { fn set_obsm(&self, obsm: Option>) -> Result<()> { let inner = self.inner(); if let Some(o) = obsm { - inner.set_obsm(o.into_iter().map(|(k, v)| (k, v.into())))?; + inner.set_obsm(o)?; } else { inner.del_obsm()?; } @@ -665,7 +669,7 @@ impl AnnDataSetTrait for Slot> { fn set_obsp(&self, obsp: Option>) -> Result<()> { let inner = self.inner(); if let Some(o) = obsp { - inner.set_obsp(o.into_iter().map(|(k, v)| (k, v.into())))?; + inner.set_obsp(o)?; } else { inner.del_obsp()?; } @@ -674,7 +678,7 @@ impl AnnDataSetTrait for Slot> { fn set_varm(&self, varm: Option>) -> Result<()> { let inner = self.inner(); if let Some(v) = varm { - inner.set_varm(v.into_iter().map(|(k, v)| (k, v.into())))?; + inner.set_varm(v)?; } else { inner.del_varm()?; } @@ -683,7 +687,7 @@ impl AnnDataSetTrait for Slot> { fn set_varp(&self, varp: Option>) -> Result<()> { let inner = self.inner(); if let Some(v) = varp { - inner.set_varp(v.into_iter().map(|(k, v)| (k, v.into())))?; + inner.set_varp(v)?; } else { inner.del_varp()?; } @@ -721,10 +725,11 @@ impl AnnDataSetTrait for Slot> { slice: &[SelectInfoElem], copy_x: bool, file: Option, - backend: &str, + backend: Option<&str>, ) -> Result { let inner = self.inner(); if let Some(file) = file { + let backend = get_backend(&file, backend); match backend { H5::NAME => inner .to_adata_select::(slice, file, copy_x) diff --git a/python/pyproject.toml b/python/pyproject.toml index beab2f8..79e4072 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -34,4 +34,5 @@ dependencies = [ "polars>=1.9", ] +[project.dev-dependencies] test = ["pytest", "hypothesis==6.72.4"] \ No newline at end of file