Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
kaizhang committed Nov 5, 2024
1 parent 5512063 commit 5bc62ce
Show file tree
Hide file tree
Showing 9 changed files with 506 additions and 686 deletions.
6 changes: 3 additions & 3 deletions anndata-test-utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ mod common;
pub use common::*;

use anndata::{data::CsrNonCanonical, *};
use data::DynArray;
use data::ArrayConvert;
use nalgebra_sparse::{CooMatrix, CsrMatrix};
use ndarray::Array2;
use proptest::prelude::*;
Expand Down Expand Up @@ -68,8 +68,8 @@ where
let arr2 = Array2::<i32>::zeros((10, 20));
assert!(adata.obsm().add("test", &arr2).is_err());

// Automatical data type casting
let _: Array2<f64> = adata.x().get::<DynArray>().unwrap().unwrap().try_into().expect("Automatical data type casting failed");
// Data type casting
let _: Array2<f64> = adata.x().get::<ArrayData>().unwrap().unwrap().try_convert().expect("data type casting failed");
}

pub fn test_noncanonical<F, T>(adata_gen: F)
Expand Down
10 changes: 5 additions & 5 deletions anndata/src/backend.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mod datatype;
use crate::data::{DynArray, SelectInfo, SelectInfoElem, Shape};
use crate::data::{ArrayConvert, DynArray, SelectInfo, SelectInfoElem, Shape};
pub use datatype::{BackendData, DataType, ScalarType};

use anyhow::{bail, Result};
Expand Down Expand Up @@ -193,11 +193,11 @@ pub trait DatasetOp<B: Backend + ?Sized> {

fn read_array_slice_cast<T, D, S>(&self, selection: &[S]) -> Result<Array<T, D>>
where
Array<T, D>: TryFrom<DynArray, Error=anyhow::Error>,
DynArray: ArrayConvert<Array<T, D>>,
D: Dimension,
S: AsRef<SelectInfoElem>
{
self.read_dyn_array_slice(selection)?.try_into()
self.read_dyn_array_slice(selection)?.try_convert()
}

fn read_array<T: BackendData, D>(&self) -> Result<Array<T, D>>
Expand All @@ -213,10 +213,10 @@ pub trait DatasetOp<B: Backend + ?Sized> {

fn read_array_cast<T, D>(&self) -> Result<Array<T, D>>
where
Array<T, D>: TryFrom<DynArray, Error=anyhow::Error>,
DynArray: ArrayConvert<Array<T, D>>,
D: Dimension,
{
self.read_dyn_array()?.try_into()
self.read_dyn_array()?.try_convert()
}

fn read_scalar<T: BackendData>(&self) -> Result<T> {
Expand Down
89 changes: 48 additions & 41 deletions anndata/src/data/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ pub mod utils;

pub use chunks::ArrayChunk;
pub use dataframe::DataFrameIndex;
pub use dense::{CategoricalArray, DynArray, DynCowArray, DynScalar};
pub use dense::{ArrayConvert, CategoricalArray, DynArray, DynCowArray, DynScalar};
pub use slice::{SelectInfo, SelectInfoBounds, SelectInfoElem, SelectInfoElemBounds, Shape};
pub use sparse::{CsrNonCanonical, DynCscMatrix, DynCsrMatrix, DynCsrNonCanonical};

Expand Down Expand Up @@ -118,13 +118,57 @@ impl TryFrom<ArrayData> for DataFrame {
}
}

impl<T, D> TryFrom<ArrayData> for Array<T, D>
where Array<T, D>: TryFrom<DynArray, Error = anyhow::Error>
{
type Error = anyhow::Error;
fn try_from(value: ArrayData) -> Result<Self, Self::Error> {
DynArray::try_from(value)?.try_into()
}
}

impl<T> TryFrom<ArrayData> for CsrMatrix<T>
where CsrMatrix<T>: TryFrom<DynCsrMatrix, Error = anyhow::Error>
{
type Error = anyhow::Error;
fn try_from(value: ArrayData) -> Result<Self, Self::Error> {
DynCsrMatrix::try_from(value)?.try_into()
}
}

impl<T> TryFrom<ArrayData> for CscMatrix<T>
where CscMatrix<T>: TryFrom<DynCscMatrix, Error = anyhow::Error>
{
type Error = anyhow::Error;
fn try_from(value: ArrayData) -> Result<Self, Self::Error> {
DynCscMatrix::try_from(value)?.try_into()
}
}

impl<T> TryFrom<ArrayData> for CsrNonCanonical<T>
where CsrNonCanonical<T>: TryFrom<DynCsrNonCanonical, Error = anyhow::Error>
{
type Error = anyhow::Error;
fn try_from(value: ArrayData) -> Result<Self, Self::Error> {
DynCsrNonCanonical::try_from(value)?.try_into()
}
}

impl<T, D> ArrayConvert<Array<T, D>> for ArrayData
where DynArray: ArrayConvert<Array<T, D>>
{
fn try_convert(self) -> Result<Array<T, D>> {
DynArray::try_from(self)?.try_convert()
}
}

/// macro for implementing From trait for Data from a list of types
macro_rules! impl_into_array_data {
macro_rules! impl_arraydata_traits {
($($ty:ty),*) => {
$(
impl<D: RemoveAxis> From<Array<$ty, D>> for ArrayData {
fn from(data: Array<$ty, D>) -> Self {
ArrayData::Array(data.into_dyn().into())
ArrayData::Array(data.into())
}
}
impl From<CsrMatrix<$ty>> for ArrayData {
Expand All @@ -142,48 +186,11 @@ macro_rules! impl_into_array_data {
ArrayData::CscMatrix(data.into())
}
}
impl<D: RemoveAxis> TryFrom<ArrayData> for Array<$ty, D> {
type Error = anyhow::Error;
fn try_from(value: ArrayData) -> Result<Self, Self::Error> {
match value {
ArrayData::Array(data) => data.try_into(),
_ => bail!("Cannot convert {:?} to {} Array", value.data_type(), stringify!($ty)),
}
}
}
impl TryFrom<ArrayData> for CsrMatrix<$ty> {
type Error = anyhow::Error;
fn try_from(value: ArrayData) -> Result<Self, Self::Error> {
match value {
ArrayData::CsrMatrix(data) => data.try_into(),
_ => bail!("Cannot convert {:?} to {} CsrMatrix", value.data_type(), stringify!($ty)),
}
}
}
impl TryFrom<ArrayData> for CsrNonCanonical<$ty> {
type Error = anyhow::Error;
fn try_from(value: ArrayData) -> Result<Self, Self::Error> {
match value {
ArrayData::CsrNonCanonical(data) => data.try_into(),
ArrayData::CsrMatrix(data) => DynCsrNonCanonical::from(data).try_into(),
_ => bail!("Cannot convert {:?} to {} CsrNonCanonical", value.data_type(), stringify!($ty)),
}
}
}
impl TryFrom<ArrayData> for CscMatrix<$ty> {
type Error = anyhow::Error;
fn try_from(value: ArrayData) -> Result<Self, Self::Error> {
match value {
ArrayData::CscMatrix(data) => data.try_into(),
_ => bail!("Cannot convert {:?} to {} CsrMatrix", value.data_type(), stringify!($ty)),
}
}
}
)*
};
}

impl_into_array_data!(i8, i16, i32, i64, u8, u16, u32, u64, f32, f64, bool, String);
impl_arraydata_traits!(i8, i16, i32, i64, u8, u16, u32, u64, f32, f64, bool, String);

impl WriteData for ArrayData {
fn data_type(&self) -> DataType {
Expand Down
24 changes: 12 additions & 12 deletions anndata/src/data/array/chunks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,18 @@ impl ArrayChunk for DynArray {
{
let mut iter = iter.peekable();
match iter.peek().context("input iterator is empty")? {
DynArray::U8(_) => ArrayD::<u8>::write_by_chunk(iter.map(|x| x.into_u8().unwrap()), location, name),
DynArray::U16(_) => ArrayD::<u16>::write_by_chunk(iter.map(|x| x.into_u16().unwrap()), location, name),
DynArray::U32(_) => ArrayD::<u32>::write_by_chunk(iter.map(|x| x.into_u32().unwrap()), location, name),
DynArray::U64(_) => ArrayD::<u64>::write_by_chunk(iter.map(|x| x.into_u64().unwrap()), location, name),
DynArray::I8(_) => ArrayD::<i8>::write_by_chunk(iter.map(|x| x.into_i8().unwrap()), location, name),
DynArray::I16(_) => ArrayD::<i16>::write_by_chunk(iter.map(|x| x.into_i16().unwrap()), location, name),
DynArray::I32(_) => ArrayD::<i32>::write_by_chunk(iter.map(|x| x.into_i32().unwrap()), location, name),
DynArray::I64(_) => ArrayD::<i64>::write_by_chunk(iter.map(|x| x.into_i64().unwrap()), location, name),
DynArray::F32(_) => ArrayD::<f32>::write_by_chunk(iter.map(|x| x.into_f32().unwrap()), location, name),
DynArray::F64(_) => ArrayD::<f64>::write_by_chunk(iter.map(|x| x.into_f64().unwrap()), location, name),
DynArray::Bool(_) => ArrayD::<bool>::write_by_chunk(iter.map(|x| x.into_bool().unwrap()), location, name),
DynArray::String(_) => ArrayD::<String>::write_by_chunk(iter.map(|x| x.into_string().unwrap()), location, name),
DynArray::U8(_) => ArrayD::<u8>::write_by_chunk(iter.map(|x| x.try_into().unwrap()), location, name),
DynArray::U16(_) => ArrayD::<u16>::write_by_chunk(iter.map(|x| x.try_into().unwrap()), location, name),
DynArray::U32(_) => ArrayD::<u32>::write_by_chunk(iter.map(|x| x.try_into().unwrap()), location, name),
DynArray::U64(_) => ArrayD::<u64>::write_by_chunk(iter.map(|x| x.try_into().unwrap()), location, name),
DynArray::I8(_) => ArrayD::<i8>::write_by_chunk(iter.map(|x| x.try_into().unwrap()), location, name),
DynArray::I16(_) => ArrayD::<i16>::write_by_chunk(iter.map(|x| x.try_into().unwrap()), location, name),
DynArray::I32(_) => ArrayD::<i32>::write_by_chunk(iter.map(|x| x.try_into().unwrap()), location, name),
DynArray::I64(_) => ArrayD::<i64>::write_by_chunk(iter.map(|x| x.try_into().unwrap()), location, name),
DynArray::F32(_) => ArrayD::<f32>::write_by_chunk(iter.map(|x| x.try_into().unwrap()), location, name),
DynArray::F64(_) => ArrayD::<f64>::write_by_chunk(iter.map(|x| x.try_into().unwrap()), location, name),
DynArray::Bool(_) => ArrayD::<bool>::write_by_chunk(iter.map(|x| x.try_into().unwrap()), location, name),
DynArray::String(_) => ArrayD::<String>::write_by_chunk(iter.map(|x| x.try_into().unwrap()), location, name),
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion anndata/src/data/array/dense.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
mod dynamic;

pub use dynamic::{DynScalar, DynArray, DynCowArray};
pub use dynamic::{ArrayConvert, DynScalar, DynArray, DynCowArray};

use crate::{
backend::*,
Expand Down
Loading

0 comments on commit 5bc62ce

Please sign in to comment.