Skip to content

Commit

Permalink
infer backend from filename
Browse files Browse the repository at this point in the history
  • Loading branch information
kaizhang committed Nov 10, 2024
1 parent e3d06d8 commit 32e93fe
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 60 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
cd ${GITHUB_WORKSPACE}/anndata-hdf5 && cargo test --no-fail-fast
cd ${GITHUB_WORKSPACE}/anndata && cargo test --no-fail-fast
cd ${GITHUB_WORKSPACE}/anndata-test-utils && cargo test --no-fail-fast
cd ${GITHUB_WORKSPACE}/python && pip install --user .[test]
cd ${GITHUB_WORKSPACE}/python && pip install --user '.[test]'
pytest -v --durations=0 ${GITHUB_WORKSPACE}/python/tests
- name: benchmark
Expand Down
84 changes: 67 additions & 17 deletions anndata/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,40 +92,66 @@ pub trait AnnDataOp {
fn layers(&self) -> Self::AxisArraysRef<'_>;

/// Sets the unstructured data.
fn set_uns<I: Iterator<Item = (String, Data)>>(&self, mut data: I) -> Result<()> {
fn set_uns<I, D>(&self, data: I) -> Result<()>
where
I: IntoIterator<Item = (String, D)>,
D: Into<Data>,
{
self.del_uns()?;
let uns = self.uns();
data.try_for_each(|(k, v)| uns.add(&k, v))
data.into_iter().try_for_each(|(k, v)| uns.add(&k, v))
}

/// Sets the observation matrix.
fn set_obsm<I: Iterator<Item = (String, ArrayData)>>(&self, mut data: I) -> Result<()> {
fn set_obsm<I, D>(&self, data: I) -> Result<()>
where
I: IntoIterator<Item = (String, D)>,
D: Into<ArrayData>,
{
self.del_obsm()?;
let obsm = self.obsm();
data.try_for_each(|(k, v)| obsm.add(&k, v))
data.into_iter().try_for_each(|(k, v)| obsm.add(&k, v))
}

/// Sets the observation pairwise data.
fn set_obsp<I: Iterator<Item = (String, ArrayData)>>(&self, mut data: I) -> Result<()> {
fn set_obsp<I, D>(&self, data: I) -> Result<()>
where
I: IntoIterator<Item = (String, D)>,
D: Into<ArrayData>,
{
self.del_obsp()?;
let obsp = self.obsp();
data.try_for_each(|(k, v)| obsp.add(&k, v))
data.into_iter().try_for_each(|(k, v)| obsp.add(&k, v))
}
/// Sets the variable matrix.
fn set_varm<I: Iterator<Item = (String, ArrayData)>>(&self, mut data: I) -> Result<()> {
fn set_varm<I, D>(&self, data: I) -> Result<()>
where
I: IntoIterator<Item = (String, D)>,
D: Into<ArrayData>,
{
self.del_varm()?;
let varm = self.varm();
data.try_for_each(|(k, v)| varm.add(&k, v))
data.into_iter().try_for_each(|(k, v)| varm.add(&k, v))
}
/// Sets the variable pairwise data.
fn set_varp<I: Iterator<Item = (String, ArrayData)>>(&self, mut data: I) -> Result<()> {
fn set_varp<I, D>(&self, data: I) -> Result<()>
where
I: IntoIterator<Item = (String, D)>,
D: Into<ArrayData>,
{
self.del_varp()?;
let varp = self.varp();
data.try_for_each(|(k, v)| varp.add(&k, v))
data.into_iter().try_for_each(|(k, v)| varp.add(&k, v))
}
/// Sets the layers.
fn set_layers<I: Iterator<Item = (String, ArrayData)>>(&self, mut data: I) -> Result<()> {
fn set_layers<I, D>(&self, data: I) -> Result<()>
where
I: IntoIterator<Item = (String, D)>,
D: Into<ArrayData>,
{
self.del_layers()?;
let layers = self.layers();
data.try_for_each(|(k, v)| layers.add(&k, v))
data.into_iter().try_for_each(|(k, v)| layers.add(&k, v))
}

/// Deletes the unstructured data.
Expand Down Expand Up @@ -520,19 +546,43 @@ impl<B: Backend> AnnDataOp for AnnDataSet<B> {
self.annotation.layers()
}

fn set_uns<I: Iterator<Item = (String, Data)>>(&self, data: I) -> Result<()> {
fn set_uns<I, D>(&self, data: I) -> Result<()>
where
I: IntoIterator<Item = (String, D)>,
D: Into<Data>,
{
self.annotation.set_uns(data)
}
fn set_obsm<I: Iterator<Item = (String, ArrayData)>>(&self, data: I) -> Result<()> {

fn set_obsm<I, D>(&self, data: I) -> Result<()>
where
I: IntoIterator<Item = (String, D)>,
D: Into<ArrayData>,
{
self.annotation.set_obsm(data)
}
fn set_obsp<I: Iterator<Item = (String, ArrayData)>>(&self, data: I) -> Result<()> {

fn set_obsp<I, D>(&self, data: I) -> Result<()>
where
I: IntoIterator<Item = (String, D)>,
D: Into<ArrayData>,
{
self.annotation.set_obsp(data)
}
fn set_varm<I: Iterator<Item = (String, ArrayData)>>(&self, data: I) -> Result<()> {

fn set_varm<I, D>(&self, data: I) -> Result<()>
where
I: IntoIterator<Item = (String, D)>,
D: Into<ArrayData>,
{
self.annotation.set_varm(data)
}
fn set_varp<I: Iterator<Item = (String, ArrayData)>>(&self, data: I) -> Result<()> {

fn set_varp<I, D>(&self, data: I) -> Result<()>
where
I: IntoIterator<Item = (String, D)>,
D: Into<ArrayData>,
{
self.annotation.set_varp(data)
}

Expand Down
43 changes: 32 additions & 11 deletions pyanndata/src/anndata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,25 @@ use anndata;
use anndata::Backend;
use anndata_hdf5::H5;
use pyo3::prelude::*;
use std::{path::PathBuf, collections::HashMap};
use std::{collections::HashMap, path::{Path, PathBuf}};
use anyhow::Result;

/// Resolves the name of the storage backend to use for `filename`.
///
/// An explicitly supplied `backend` is returned unchanged. Otherwise the
/// backend is inferred from the file extension: `zarr` and `zarrs` select
/// `Zarr::NAME`; every other extension (including `h5ad`, `h5` and `h5ads`),
/// a missing extension, or an extension that is not valid UTF-8 selects
/// `H5::NAME`. The comparison is case-sensitive.
pub(crate) fn get_backend<P: AsRef<Path>>(filename: P, backend: Option<&str>) -> &str {
    backend.unwrap_or_else(|| {
        // `and_then(to_str)` maps a non-UTF-8 extension to `None`, so such a
        // path falls through to the default backend instead of panicking.
        match filename.as_ref().extension().and_then(|ext| ext.to_str()) {
            Some("zarr" | "zarrs") => Zarr::NAME,
            // "h5ad", "h5", "h5ads" and anything unrecognised use HDF5.
            _ => H5::NAME,
        }
    })
}

/// Read `.h5ad`-formatted hdf5 file.
///
/// Parameters
Expand All @@ -29,12 +45,15 @@ use anyhow::Result;
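/// backend: Literal['hdf5', 'zarr'] | None
/// Backend used when the file is opened in backed mode. If None, it is
/// inferred from the extension of `filename` ('.zarr' or '.zarrs' select
/// 'zarr'; anything else selects 'hdf5').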
#[pyfunction]
#[pyo3(
signature = (filename, backed="r+", backend=H5::NAME),
text_signature = "(filename, backed='r+', backend='hdf5')",
signature = (filename, backed="r+", backend=None),
text_signature = "(filename, backed='r+', backend=None)",
)]
pub fn read<'py>(py: Python<'py>, filename: PathBuf, backed: Option<&str>, backend: &str) -> Result<PyObject> {
pub fn read<'py>(py: Python<'py>, filename: PathBuf, backed: Option<&str>, backend: Option<&str>) -> Result<PyObject> {
let adata = match backed {
Some(m) => AnnData::new_from(filename, m, backend).unwrap().into_py(py),
Some(m) => {
let backend = get_backend(&filename, backend);
AnnData::new_from(filename, m, backend).unwrap().into_py(py)
},
None => PyModule::import_bound(py, "anndata")?
.getattr("read_h5ad")?
.call1((filename,))?
Expand Down Expand Up @@ -93,16 +112,16 @@ pub fn concat<'py>(
/// Sorted input matrix can be read faster.
#[pyfunction]
#[pyo3(
signature = (mtx_file, *, obs_names=None, var_names=None, file=None, backend=H5::NAME, sorted=false),
text_signature = "(mtx_file, *, obs_names=None, var_names=None, file=None, backend='hdf5', sorted=False)",
signature = (mtx_file, *, obs_names=None, var_names=None, file=None, backend=None, sorted=false),
text_signature = "(mtx_file, *, obs_names=None, var_names=None, file=None, backend=None, sorted=False)",
)]
pub fn read_mtx(
py: Python<'_>,
mtx_file: PathBuf,
obs_names: Option<PathBuf>,
var_names: Option<PathBuf>,
file: Option<PathBuf>,
backend: &str,
backend: Option<&str>,
sorted: bool,
) -> Result<PyObject> {
let mut reader = anndata::reader::MMReader::from_path(mtx_file)?;
Expand All @@ -116,6 +135,7 @@ pub fn read_mtx(
reader = reader.is_sorted();
}
if let Some(file) = file {
let backend = get_backend(&file, backend);
match backend {
H5::NAME => {
let adata = anndata::AnnData::<H5>::new(file)?;
Expand Down Expand Up @@ -162,20 +182,21 @@ pub fn read_mtx(
/// AnnDataSet
#[pyfunction]
#[pyo3(
signature = (filename, *, adata_files_update=None, mode="r+", backend=H5::NAME),
text_signature = "(filename, *, adata_files_update=None, mode='r+', backend='hdf5')",
signature = (filename, *, adata_files_update=None, mode="r+", backend=None),
text_signature = "(filename, *, adata_files_update=None, mode='r+', backend=None)",
)]
pub fn read_dataset(
filename: PathBuf,
adata_files_update: Option<LocationUpdate>,
mode: &str,
backend: &str,
backend: Option<&str>,
) -> Result<AnnDataSet> {
let adata_files_update = match adata_files_update {
Some(LocationUpdate::Map(map)) => Some(Ok(map)),
Some(LocationUpdate::Dir(dir)) => Some(Err(dir)),
None => None,
};
let backend = get_backend(&filename, backend);
match backend {
H5::NAME => {
let file = match mode {
Expand Down
34 changes: 19 additions & 15 deletions pyanndata/src/anndata/backed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ use std::collections::HashMap;
use std::ops::Deref;
use std::path::PathBuf;

use super::get_backend;

/** An annotated data matrix.
`AnnData` stores a data matrix `X` together with annotations of
Expand Down Expand Up @@ -182,10 +184,10 @@ impl AnnData {
#[pyo3(
signature = (
*, filename, X=None, obs=None, var=None, obsm=None,
varm=None, uns=None, backend=H5::NAME,
varm=None, uns=None, backend=None,
),
text_signature = "($self, *, filename, X=None, obs=None, var=None, obsm=None,
varm=None, uns=None, backend='hdf5')"
varm=None, uns=None, backend=None)"
)]
pub fn new(
filename: PathBuf,
Expand All @@ -195,8 +197,9 @@ impl AnnData {
obsm: Option<HashMap<String, PyArrayData>>,
varm: Option<HashMap<String, PyArrayData>>,
uns: Option<HashMap<String, PyData>>,
backend: &str,
backend: Option<&str>,
) -> Result<Self> {
let backend = get_backend(&filename, backend);
let adata: AnnData = match backend {
H5::NAME => anndata::AnnData::<H5>::new(filename)?.into(),
Zarr::NAME => anndata::AnnData::<Zarr>::new(filename)?.into(),
Expand Down Expand Up @@ -407,15 +410,15 @@ impl AnnData {
/// is returned. This parameter is ignored when `inplace=True`.
/// inplace: bool
/// Whether to modify the AnnData object in place or return a new AnnData object.
/// backend: str
/// backend: str | None
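/// The backend to use. "hdf5" or "zarr" are supported. If None, the backend
/// is inferred from the extension of the output file (".zarr" or ".zarrs"
/// select "zarr"; anything else selects "hdf5").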
///
/// Returns
/// -------
/// Optional[AnnData]
#[pyo3(
signature = (obs_indices=None, var_indices=None, *, out=None, inplace=true, backend=H5::NAME),
text_signature = "($self, obs_indices=None, var_indices=None, *, out=None, inplace=True, backend='hdf5')",
signature = (obs_indices=None, var_indices=None, *, out=None, inplace=true, backend=None),
text_signature = "($self, obs_indices=None, var_indices=None, *, out=None, inplace=True, backend=None)",
)]
pub fn subset(
&self,
Expand All @@ -424,7 +427,7 @@ impl AnnData {
var_indices: Option<&Bound<'_, PyAny>>,
out: Option<PathBuf>,
inplace: bool,
backend: &str,
backend: Option<&str>,
) -> Result<Option<PyObject>> {
let i = obs_indices
.map(|x| self.select_obs(x).unwrap())
Expand Down Expand Up @@ -594,7 +597,7 @@ trait AnnDataTrait: Send + Downcast {
slice: &[SelectInfoElem],
file: Option<PathBuf>,
inplace: bool,
backend: &str,
backend: Option<&str>,
) -> Result<Option<PyObject>>;

fn chunked_x(&self, chunk_size: usize) -> PyChunkedArray;
Expand Down Expand Up @@ -801,7 +804,7 @@ impl<B: Backend> AnnDataTrait for InnerAnnData<B> {
fn set_uns(&self, uns: Option<HashMap<String, PyData>>) -> Result<()> {
let inner = self.adata.inner();
if let Some(u) = uns {
inner.set_uns(u.into_iter().map(|(k, v)| (k, v.into())))?;
inner.set_uns(u)?;
} else {
inner.del_uns()?;
}
Expand All @@ -810,7 +813,7 @@ impl<B: Backend> AnnDataTrait for InnerAnnData<B> {
fn set_obsm(&self, obsm: Option<HashMap<String, PyArrayData>>) -> Result<()> {
let inner = self.adata.inner();
if let Some(o) = obsm {
inner.set_obsm(o.into_iter().map(|(k, v)| (k, v.into())))?;
inner.set_obsm(o)?;
} else {
inner.del_obsm()?;
}
Expand All @@ -819,7 +822,7 @@ impl<B: Backend> AnnDataTrait for InnerAnnData<B> {
fn set_obsp(&self, obsp: Option<HashMap<String, PyArrayData>>) -> Result<()> {
let inner = self.adata.inner();
if let Some(o) = obsp {
inner.set_obsp(o.into_iter().map(|(k, v)| (k, v.into())))?;
inner.set_obsp(o)?;
} else {
inner.del_obsp()?;
}
Expand All @@ -828,7 +831,7 @@ impl<B: Backend> AnnDataTrait for InnerAnnData<B> {
fn set_varm(&self, varm: Option<HashMap<String, PyArrayData>>) -> Result<()> {
let inner = self.adata.inner();
if let Some(v) = varm {
inner.set_varm(v.into_iter().map(|(k, v)| (k, v.into())))?;
inner.set_varm(v)?;
} else {
inner.del_varm()?;
}
Expand All @@ -837,7 +840,7 @@ impl<B: Backend> AnnDataTrait for InnerAnnData<B> {
fn set_varp(&self, varp: Option<HashMap<String, PyArrayData>>) -> Result<()> {
let inner = self.adata.inner();
if let Some(v) = varp {
inner.set_varp(v.into_iter().map(|(k, v)| (k, v.into())))?;
inner.set_varp(v)?;
} else {
inner.del_varp()?;
}
Expand All @@ -846,7 +849,7 @@ impl<B: Backend> AnnDataTrait for InnerAnnData<B> {
fn set_layers(&self, varp: Option<HashMap<String, PyArrayData>>) -> Result<()> {
let inner = self.adata.inner();
if let Some(v) = varp {
inner.set_layers(v.into_iter().map(|(k, v)| (k, v.into())))?;
inner.set_layers(v)?;
} else {
inner.del_layers()?;
}
Expand All @@ -859,13 +862,14 @@ impl<B: Backend> AnnDataTrait for InnerAnnData<B> {
slice: &[SelectInfoElem],
file: Option<PathBuf>,
inplace: bool,
backend: &str,
backend: Option<&str>,
) -> Result<Option<PyObject>> {
let inner = self.adata.inner();
if inplace {
inner.subset(slice)?;
Ok(None)
} else if let Some(file) = file {
let backend = get_backend(&file, backend);
match backend {
H5::NAME => {
inner.write_select::<H5, _, _>(slice, &file)?;
Expand Down
Loading

0 comments on commit 32e93fe

Please sign in to comment.