Skip to content

Commit

Permalink
upgrade polars
Browse files Browse the repository at this point in the history
  • Loading branch information
kaizhang committed Dec 9, 2024
1 parent 1423ff5 commit 99aa68c
Show file tree
Hide file tree
Showing 9 changed files with 118 additions and 63 deletions.
2 changes: 1 addition & 1 deletion anndata-test-utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ edition = "2021"

[dependencies]
anyhow = "1.0"
ndarray = { version = "0.16" }
ndarray = "0.16"
anndata = { workspace = true }
num = "0.4"
tempfile = "3.2"
Expand Down
2 changes: 1 addition & 1 deletion anndata-zarr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ anndata = { workspace = true }
serde_json = "1.0"
anyhow = "1.0"
ndarray = { version = "0.16", features = ["serde"] }
zarrs = "0.17"
zarrs = "0.18"
smallvec = "1.13"

[dev-dependencies]
Expand Down
4 changes: 1 addition & 3 deletions anndata/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ itertools = "0.13"
ndarray = "0.16"
nalgebra-sparse = "0.10"
num = "0.4"
polars = { version = "=0.43.0", features = ["lazy", "decompress-fast", "ndarray", "dtype-full"] }
polars = { version = "0.45.0", features = ["lazy", "decompress-fast", "ndarray", "dtype-full"] }
paste = "1.0"
parking_lot = "0.12"
smallvec = "1.13"
Expand All @@ -28,8 +28,6 @@ serde_json = "1.0"
rayon = "1.10"
permutation = "0.4"

hashbrown = { version = "0.14.5", features = ["raw"] }

[dev-dependencies]
tempfile = "3.2"
proptest = "1"
Expand Down
134 changes: 97 additions & 37 deletions anndata/src/anndata/dataset.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
use crate::{
traits::{AnnDataOp, ElemCollectionOp},
anndata::AnnData,
backend::Backend,
container::{Slot, Dim, Axis, AxisArrays, StackedArrayElem, StackedAxisArrays, StackedDataFrame},
data::*,
container::{
Axis, AxisArrays, Dim, Slot, StackedArrayElem, StackedAxisArrays, StackedDataFrame,
},
data::index::VecVecIndex,
data::*,
traits::{AnnDataOp, ElemCollectionOp},
};

use anyhow::{anyhow, bail, ensure, Context, Result};
use indexmap::map::IndexMap;
use itertools::Itertools;
use polars::{df, prelude::{DataFrame, NamedFrom, Series}};
use polars::{
df,
prelude::{Column, DataFrame},
};
use rayon::iter::{IndexedParallelIterator, ParallelIterator};
use std::{collections::{HashMap, HashSet}, path::{Path, PathBuf}};
use std::{
collections::{HashMap, HashSet},
path::{Path, PathBuf},
};

pub struct AnnDataSet<B: Backend> {
pub(crate) annotation: AnnData<B>,
Expand Down Expand Up @@ -140,7 +148,8 @@ impl<B: Backend> AnnDataSet<B> {
let mut annotation = AnnData::new(filename)?;
annotation.n_obs = Dim::new(n_obs);
annotation.n_vars = Dim::new(n_vars);
{ // Set UNS. UNS includes children anndata locations and shared elements.
{
// Set UNS. UNS includes children anndata locations and shared elements.
let (keys, filenames): (Vec<_>, Vec<_>) = anndatas
.iter()
.map(|(k, v)| (k.clone(), v.filename().display().to_string()))
Expand All @@ -155,13 +164,30 @@ impl<B: Backend> AnnDataSet<B> {
.reduce(|a, b| a.intersection(&b).cloned().collect())
.unwrap_or(HashSet::new());
for key in shared_keys {
if anndatas.values().map(|x| x.uns().get_item::<Data>(&key).unwrap().unwrap()).all_equal() {
annotation.uns().add(&key, anndatas.values().next().unwrap().uns().get_item::<Data>(&key)?.unwrap())?;
if anndatas
.values()
.map(|x| x.uns().get_item::<Data>(&key).unwrap().unwrap())
.all_equal()
{
annotation.uns().add(
&key,
anndatas
.values()
.next()
.unwrap()
.uns()
.get_item::<Data>(&key)?
.unwrap(),
)?;
}
}
}
{ // Set OBS.
let obs_names: DataFrameIndex = anndatas.values().flat_map(|x| x.obs_names().into_iter()).collect();
{
// Set OBS.
let obs_names: DataFrameIndex = anndatas
.values()
.flat_map(|x| x.obs_names().into_iter())
.collect();
if !obs_names.is_empty() && obs_names.len() == n_obs {
annotation.set_obs_names(obs_names)?;
}
Expand All @@ -172,7 +198,8 @@ impl<B: Backend> AnnDataSet<B> {
.collect::<Vec<_>>();
annotation.set_obs(df!(add_key => keys)?)?;
}
{ // Set VAR.
{
// Set VAR.
let adata = anndatas.values().next().unwrap();
let var_names = adata.var_names();
if !var_names.is_empty() {
Expand All @@ -187,7 +214,7 @@ impl<B: Backend> AnnDataSet<B> {

pub fn open<P: AsRef<Path>>(
file: B::Store,
adata_files_update: Option<Result<HashMap<String, P>, P>>
adata_files_update: Option<Result<HashMap<String, P>, P>>,
) -> Result<Self> {
let annotation: AnnData<B> = AnnData::open(file)?;
let file_path = annotation
Expand Down Expand Up @@ -231,17 +258,20 @@ impl<B: Backend> AnnDataSet<B> {
selection: S,
dir: P,
) -> Result<Option<Vec<usize>>> {
selection.as_ref()[0].bound_check(self.n_obs())
selection.as_ref()[0]
.bound_check(self.n_obs())
.map_err(|e| anyhow!("AnnDataSet obs {}", e))?;
selection.as_ref()[1].bound_check(self.n_vars())
selection.as_ref()[1]
.bound_check(self.n_vars())
.map_err(|e| anyhow!("AnnDataSet var {}", e))?;

let file = dir.as_ref().join("_dataset.h5ads");
let anndata_dir = dir.as_ref().join("anndatas");
std::fs::create_dir_all(&anndata_dir)?;

let (files, obs_idx_order) =
self.anndatas.inner()
self.anndatas
.inner()
.write_select::<O, _, _>(&selection, &anndata_dir, ".h5ad")?;

if let Some(order) = obs_idx_order.as_ref() {
Expand Down Expand Up @@ -276,8 +306,9 @@ impl<B: Backend> AnnDataSet<B> {
self.annotation.write::<O, _>(&out)?;
let adata = AnnData::open(O::open_rw(&out)?)?;
if copy_x {
adata
.set_x_from_iter::<_, ArrayData>(self.anndatas.inner().x.chunked(500).map(|x| x.0))?;
adata.set_x_from_iter::<_, ArrayData>(
self.anndatas.inner().x.chunked(500).map(|x| x.0),
)?;
}
Ok(adata)
}
Expand All @@ -300,8 +331,9 @@ impl<B: Backend> AnnDataSet<B> {
/// Convert AnnDataSet to AnnData object
pub fn into_adata(self, copy_x: bool) -> Result<AnnData<B>> {
if copy_x {
self.annotation
.set_x_from_iter::<_, ArrayData>(self.anndatas.inner().x.chunked(500).map(|x| x.0))?;
self.annotation.set_x_from_iter::<_, ArrayData>(
self.anndatas.inner().x.chunked(500).map(|x| x.0),
)?;
}
for ann in self.anndatas.extract().unwrap().elems.into_values() {
ann.close()?;
Expand All @@ -324,7 +356,8 @@ fn update_anndata_locations_by_map<B: Backend, P: AsRef<Path>>(
new_locations: HashMap<String, P>,
) -> Result<Vec<(String, PathBuf)>> {
let df: DataFrame = ann
.uns().get_item("AnnDataSet")?
.uns()
.get_item("AnnDataSet")?
.context("key 'AnnDataSet' is not present")?;
let keys = df.column("keys").unwrap();
let filenames = as_str_vec(df.column("file_path")?);
Expand All @@ -338,10 +371,17 @@ fn update_anndata_locations_by_map<B: Backend, P: AsRef<Path>>(
(k.to_string(), name)
})
.collect();
let data = DataFrame::new(
vec![keys.clone(),
Series::new("file_path".into(), new_files.iter().map(|x| x.1.to_str().unwrap().to_string()).collect::<Vec<_>>())]
).unwrap();
let data = DataFrame::new(vec![
keys.clone(),
Column::new(
"file_path".into(),
new_files
.iter()
.map(|x| x.1.to_str().unwrap().to_string())
.collect::<Vec<_>>(),
),
])
.unwrap();
if !new_locations.is_empty() {
ann.uns().add("AnnDataSet", data)?;
}
Expand All @@ -353,26 +393,39 @@ fn update_anndata_location_dir<B: Backend, P: AsRef<Path>>(
dir: P,
) -> Result<Vec<(String, PathBuf)>> {
let df: DataFrame = ann
.uns().get_item("AnnDataSet")?
.uns()
.get_item("AnnDataSet")?
.context("key 'AnnDataSet' is not present")?;
let keys = df.column("keys").unwrap();
let file_map: HashMap<String, PathBuf> = std::fs::read_dir(dir)?.map(|x| x.map(|entry|
(entry.file_name().into_string().unwrap(), entry.path())
)).collect::<Result<_, std::io::Error>>()?;
let file_map: HashMap<String, PathBuf> = std::fs::read_dir(dir)?
.map(|x| x.map(|entry| (entry.file_name().into_string().unwrap(), entry.path())))
.collect::<Result<_, std::io::Error>>()?;
let filenames = as_str_vec(df.column("file_path")?);
let new_files: Vec<_> = as_str_vec(keys)
.into_iter()
.zip(filenames)
.map(|(k, filename)| {
let path = PathBuf::from(filename);
let name = path.file_name().unwrap().to_str().unwrap();
(k, file_map.get(name).map_or(path, |x| std::fs::canonicalize(x).unwrap()))
(
k,
file_map
.get(name)
.map_or(path, |x| std::fs::canonicalize(x).unwrap()),
)
})
.collect();
let data = DataFrame::new(
vec![keys.clone(),
Series::new("file_path".into(), new_files.iter().map(|x| x.1.to_str().unwrap().to_string()).collect::<Vec<_>>())]
).unwrap();
let data = DataFrame::new(vec![
keys.clone(),
Column::new(
"file_path".into(),
new_files
.iter()
.map(|x| x.1.to_str().unwrap().to_string())
.collect::<Vec<_>>(),
),
])
.unwrap();
ann.uns().add("AnnDataSet", data)?;
Ok(new_files)
}
Expand Down Expand Up @@ -520,10 +573,17 @@ impl<B: Backend> StackedAnnData<B> {
}
}

/// Collect every value of a string or categorical column into owned `String`s.
///
/// The string accessor is tried first; when it fails, the input is read as a
/// categorical column instead.
///
/// # Panics
///
/// Panics if any value is missing (each item is unwrapped), or if the input is
/// neither a string nor a categorical column.
fn as_str_vec(series: &Series) -> Vec<String> {
fn as_str_vec(series: &Column) -> Vec<String> {
if let Ok(s) = series.str() {
s.into_iter().map(|x| x.unwrap().to_string()).collect::<Vec<_>>()
s.into_iter()
.map(|x| x.unwrap().to_string())
.collect::<Vec<_>>()
} else {
// Not a string column: fall back to the categorical accessor and read
// each entry through its string representation.
series.categorical().unwrap().iter_str().map(|x| x.unwrap().to_string()).collect::<Vec<_>>()
series
.categorical()
.unwrap()
.iter_str()
.map(|x| x.unwrap().to_string())
.collect::<Vec<_>>()
}
}
}
16 changes: 9 additions & 7 deletions anndata/src/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use itertools::Itertools;
use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::pattern::SparsityPattern;
use polars::frame::DataFrame;
use polars::prelude::{AnyValue, CategoricalChunkedBuilder, DataType, IntoLazy, NamedFrom};
use polars::prelude::{AnyValue, CategoricalChunkedBuilder, Column, DataType, IntoLazy, NamedFrom};
use polars::series::{IntoSeries, Series};

use crate::data::{ArrayData, DynArray};
Expand Down Expand Up @@ -152,8 +152,10 @@ fn merge_df(this: &mut DataFrame, other: &DataFrame) -> Result<()> {
if let Some(i) = this.get_column_index(name) {
let this_s = this.column(name)?;
let new_column = this_s
.as_series()
.unwrap()
.iter()
.zip(other_s.iter())
.zip(other_s.as_series().unwrap().iter())
.map(|(this_v, other_v)| {
if other_v.is_null() {
this_v.clone()
Expand Down Expand Up @@ -186,12 +188,12 @@ fn merge_df(this: &mut DataFrame, other: &DataFrame) -> Result<()> {
Ok(())
}

/// Reorganize a series to match the new row names, filling in missing values with `None`.
/// Reorganize a column to match the new row names, filling in missing values with `None`.
fn align_series(
series: &Series,
series: &Column,
row_names: &DataFrameIndex,
new_row_names: &IndexSet<String>,
) -> Result<Series> {
) -> Result<Column> {
let name = series.name();
let new_series = match series.dtype() {
DataType::Categorical(_, ord) => {
Expand Down Expand Up @@ -225,7 +227,7 @@ fn align_series(
Series::from_any_values_and_dtype(name.clone(), &values?, &dtype, false)?
}
};
Ok(new_series)
Ok(new_series.into())
}

fn index_array(
Expand Down Expand Up @@ -274,4 +276,4 @@ fn index_array(
ArrayData::CsrMatrix(x) => crate::macros::dyn_map!(x, DynCsrMatrix, fun_csr),
_ => todo!(),
}
}
}
6 changes: 3 additions & 3 deletions anndata/src/container/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use num::integer::div_rem;
use parking_lot::{Mutex, MutexGuard};
use polars::{
frame::DataFrame,
prelude::{concat, Series, IntoLazy, UnionArgs},
prelude::{concat, Column, IntoLazy, UnionArgs},
series::IntoSeries,
};
use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
Expand Down Expand Up @@ -164,7 +164,7 @@ impl<B: Backend> InnerDataFrameElem<B> {
self.index.len()
}

/// Borrow the column called `name` from this element's data frame.
///
/// # Errors
///
/// Returns an error if `self.data()` fails or if the column lookup for
/// `name` fails.
// NOTE(review): `&mut self` suggests `data()` may load or cache the frame on
// first access — confirm against its definition.
pub fn column(&mut self, name: &str) -> Result<&Series> {
pub fn column(&mut self, name: &str) -> Result<&Column> {
self.data().and_then(|x| Ok(x.column(name)?))
}

Expand Down Expand Up @@ -778,7 +778,7 @@ impl<B: Backend> StackedDataFrame<B> {
}

// TODO: this is not efficient, we should use the index to select the columns
pub fn column(&self, name: &str) -> Result<Series> {
pub fn column(&self, name: &str) -> Result<Column> {
if self.column_names.contains(name) {
Ok(self.data()?.column(name)?.clone())
} else {
Expand Down
10 changes: 4 additions & 6 deletions pyanndata/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,10 @@ numpy = "0.22"
ndarray = "0.16"
nalgebra-sparse = "0.10"
hdf5 = { package = "hdf5-metno", version = "0.9" }
polars = { version = "=0.43.0", features = ["ndarray"] }
#pyo3-polars = {version = "0.17", features = ["dtype-full", "dtype-struct"] }
pyo3-polars = { git = "https://github.com/pola-rs/pyo3-polars.git", rev = "d426148ae27410aa4fb10a4a9dc67647a058244f", features = ["dtype-full", "dtype-struct"] }

polars-core = "=0.43.0"
polars-arrow = "=0.43.0"
polars = { version = "0.45.0", features = ["ndarray"] }
pyo3-polars = {version = "0.19", features = ["dtype-full", "dtype-struct"] }
polars-core = "0.45.0"
polars-arrow = "0.45.0"
thiserror = "1.0"
rand = "0.8"
flate2 = "1.0"
Expand Down
Loading

0 comments on commit 99aa68c

Please sign in to comment.