diff --git a/anndata-test-utils/src/lib.rs b/anndata-test-utils/src/lib.rs index 2bfe29a..5fd3dac 100644 --- a/anndata-test-utils/src/lib.rs +++ b/anndata-test-utils/src/lib.rs @@ -2,7 +2,7 @@ mod common; pub use common::*; use anndata::{data::CsrNonCanonical, *}; -use data::DynArray; +use data::ArrayConvert; use nalgebra_sparse::{CooMatrix, CsrMatrix}; use ndarray::Array2; use proptest::prelude::*; @@ -68,8 +68,8 @@ where let arr2 = Array2::::zeros((10, 20)); assert!(adata.obsm().add("test", &arr2).is_err()); - // Automatical data type casting - let _: Array2 = adata.x().get::().unwrap().unwrap().try_into().expect("Automatical data type casting failed"); + // Data type casting + let _: Array2 = adata.x().get::().unwrap().unwrap().try_convert().expect("data type casting failed"); } pub fn test_noncanonical(adata_gen: F) diff --git a/anndata/src/backend.rs b/anndata/src/backend.rs index 8546972..337b1dc 100644 --- a/anndata/src/backend.rs +++ b/anndata/src/backend.rs @@ -1,5 +1,5 @@ mod datatype; -use crate::data::{DynArray, SelectInfo, SelectInfoElem, Shape}; +use crate::data::{ArrayConvert, DynArray, SelectInfo, SelectInfoElem, Shape}; pub use datatype::{BackendData, DataType, ScalarType}; use anyhow::{bail, Result}; @@ -193,11 +193,11 @@ pub trait DatasetOp { fn read_array_slice_cast(&self, selection: &[S]) -> Result> where - Array: TryFrom, + DynArray: ArrayConvert>, D: Dimension, S: AsRef { - self.read_dyn_array_slice(selection)?.try_into() + self.read_dyn_array_slice(selection)?.try_convert() } fn read_array(&self) -> Result> @@ -213,10 +213,10 @@ pub trait DatasetOp { fn read_array_cast(&self) -> Result> where - Array: TryFrom, + DynArray: ArrayConvert>, D: Dimension, { - self.read_dyn_array()?.try_into() + self.read_dyn_array()?.try_convert() } fn read_scalar(&self) -> Result { diff --git a/anndata/src/data/array.rs b/anndata/src/data/array.rs index 834f90e..177d0c4 100644 --- a/anndata/src/data/array.rs +++ b/anndata/src/data/array.rs @@ -7,7 +7,7 @@ pub mod utils; pub use chunks::ArrayChunk; pub use dataframe::DataFrameIndex; -pub use dense::{CategoricalArray, DynArray, DynCowArray, DynScalar}; +pub use dense::{ArrayConvert, CategoricalArray, DynArray, DynCowArray, DynScalar}; pub use slice::{SelectInfo, SelectInfoBounds, SelectInfoElem, SelectInfoElemBounds, Shape}; pub use sparse::{CsrNonCanonical, DynCscMatrix, DynCsrMatrix, DynCsrNonCanonical}; @@ -118,13 +118,57 @@ impl TryFrom for DataFrame { } } +impl TryFrom for Array +where Array: TryFrom +{ + type Error = anyhow::Error; + fn try_from(value: ArrayData) -> Result { + DynArray::try_from(value)?.try_into() + } +} + +impl TryFrom for CsrMatrix +where CsrMatrix: TryFrom +{ + type Error = anyhow::Error; + fn try_from(value: ArrayData) -> Result { + DynCsrMatrix::try_from(value)?.try_into() + } +} + +impl TryFrom for CscMatrix +where CscMatrix: TryFrom +{ + type Error = anyhow::Error; + fn try_from(value: ArrayData) -> Result { + DynCscMatrix::try_from(value)?.try_into() + } +} + +impl TryFrom for CsrNonCanonical +where CsrNonCanonical: TryFrom +{ + type Error = anyhow::Error; + fn try_from(value: ArrayData) -> Result { + DynCsrNonCanonical::try_from(value)?.try_into() + } +} + +impl ArrayConvert> for ArrayData +where DynArray: ArrayConvert> +{ + fn try_convert(self) -> Result> { + DynArray::try_from(self)?.try_convert() + } +} + /// macro for implementing From trait for Data from a list of types -macro_rules! impl_into_array_data { +macro_rules! impl_arraydata_traits { ($($ty:ty),*) => { $( impl From> for ArrayData { fn from(data: Array<$ty, D>) -> Self { - ArrayData::Array(data.into_dyn().into()) + ArrayData::Array(data.into()) } } impl From> for ArrayData { @@ -142,48 +186,11 @@ macro_rules! impl_into_array_data { ArrayData::CscMatrix(data.into()) } } - impl TryFrom for Array<$ty, D> { - type Error = anyhow::Error; - fn try_from(value: ArrayData) -> Result { - match value { - ArrayData::Array(data) => data.try_into(), - _ => bail!("Cannot convert {:?} to {} Array", value.data_type(), stringify!($ty)), - } - } - } - impl TryFrom for CsrMatrix<$ty> { - type Error = anyhow::Error; - fn try_from(value: ArrayData) -> Result { - match value { - ArrayData::CsrMatrix(data) => data.try_into(), - _ => bail!("Cannot convert {:?} to {} CsrMatrix", value.data_type(), stringify!($ty)), - } - } - } - impl TryFrom for CsrNonCanonical<$ty> { - type Error = anyhow::Error; - fn try_from(value: ArrayData) -> Result { - match value { - ArrayData::CsrNonCanonical(data) => data.try_into(), - ArrayData::CsrMatrix(data) => DynCsrNonCanonical::from(data).try_into(), - _ => bail!("Cannot convert {:?} to {} CsrNonCanonical", value.data_type(), stringify!($ty)), - } - } - } - impl TryFrom for CscMatrix<$ty> { - type Error = anyhow::Error; - fn try_from(value: ArrayData) -> Result { - match value { - ArrayData::CscMatrix(data) => data.try_into(), - _ => bail!("Cannot convert {:?} to {} CsrMatrix", value.data_type(), stringify!($ty)), - } - } - } )* }; } -impl_into_array_data!(i8, i16, i32, i64, u8, u16, u32, u64, f32, f64, bool, String); +impl_arraydata_traits!(i8, i16, i32, i64, u8, u16, u32, u64, f32, f64, bool, String); impl WriteData for ArrayData { fn data_type(&self) -> DataType { diff --git a/anndata/src/data/array/chunks.rs b/anndata/src/data/array/chunks.rs index ed6fe7a..5dfcc0a 100644 --- a/anndata/src/data/array/chunks.rs +++ b/anndata/src/data/array/chunks.rs @@ -47,18 +47,18 @@ impl ArrayChunk for DynArray { { let mut iter = iter.peekable(); match iter.peek().context("input iterator is empty")? { - DynArray::U8(_) => ArrayD::::write_by_chunk(iter.map(|x| x.into_u8().unwrap()), location, name), - DynArray::U16(_) => ArrayD::::write_by_chunk(iter.map(|x| x.into_u16().unwrap()), location, name), - DynArray::U32(_) => ArrayD::::write_by_chunk(iter.map(|x| x.into_u32().unwrap()), location, name), - DynArray::U64(_) => ArrayD::::write_by_chunk(iter.map(|x| x.into_u64().unwrap()), location, name), - DynArray::I8(_) => ArrayD::::write_by_chunk(iter.map(|x| x.into_i8().unwrap()), location, name), - DynArray::I16(_) => ArrayD::::write_by_chunk(iter.map(|x| x.into_i16().unwrap()), location, name), - DynArray::I32(_) => ArrayD::::write_by_chunk(iter.map(|x| x.into_i32().unwrap()), location, name), - DynArray::I64(_) => ArrayD::::write_by_chunk(iter.map(|x| x.into_i64().unwrap()), location, name), - DynArray::F32(_) => ArrayD::::write_by_chunk(iter.map(|x| x.into_f32().unwrap()), location, name), - DynArray::F64(_) => ArrayD::::write_by_chunk(iter.map(|x| x.into_f64().unwrap()), location, name), - DynArray::Bool(_) => ArrayD::::write_by_chunk(iter.map(|x| x.into_bool().unwrap()), location, name), - DynArray::String(_) => ArrayD::::write_by_chunk(iter.map(|x| x.into_string().unwrap()), location, name), + DynArray::U8(_) => ArrayD::::write_by_chunk(iter.map(|x| x.try_into().unwrap()), location, name), + DynArray::U16(_) => ArrayD::::write_by_chunk(iter.map(|x| x.try_into().unwrap()), location, name), + DynArray::U32(_) => ArrayD::::write_by_chunk(iter.map(|x| x.try_into().unwrap()), location, name), + DynArray::U64(_) => ArrayD::::write_by_chunk(iter.map(|x| x.try_into().unwrap()), location, name), + DynArray::I8(_) => ArrayD::::write_by_chunk(iter.map(|x| x.try_into().unwrap()), location, name), + DynArray::I16(_) => ArrayD::::write_by_chunk(iter.map(|x| x.try_into().unwrap()), location, name), + DynArray::I32(_) => ArrayD::::write_by_chunk(iter.map(|x| x.try_into().unwrap()), location, name), + DynArray::I64(_) => ArrayD::::write_by_chunk(iter.map(|x| x.try_into().unwrap()), location, name), + DynArray::F32(_) => ArrayD::::write_by_chunk(iter.map(|x| x.try_into().unwrap()), location, name), + DynArray::F64(_) => ArrayD::::write_by_chunk(iter.map(|x| x.try_into().unwrap()), location, name), + DynArray::Bool(_) => ArrayD::::write_by_chunk(iter.map(|x| x.try_into().unwrap()), location, name), + DynArray::String(_) => ArrayD::::write_by_chunk(iter.map(|x| x.try_into().unwrap()), location, name), } } } diff --git a/anndata/src/data/array/dense.rs b/anndata/src/data/array/dense.rs index fa0e41f..7e63de9 100644 --- a/anndata/src/data/array/dense.rs +++ b/anndata/src/data/array/dense.rs @@ -1,6 +1,6 @@ mod dynamic; -pub use dynamic::{DynScalar, DynArray, DynCowArray}; +pub use dynamic::{ArrayConvert, DynScalar, DynArray, DynCowArray}; use crate::{ backend::*, diff --git a/anndata/src/data/array/dense/dynamic.rs b/anndata/src/data/array/dense/dynamic.rs index 835f93b..f9cc0b3 100644 --- a/anndata/src/data/array/dense/dynamic.rs +++ b/anndata/src/data/array/dense/dynamic.rs @@ -127,17 +127,6 @@ macro_rules! impl_dynarray_into_array{ ($($variant:ident, $scalar_ty:ident),*) => { $( paste! { - /// Extract the concrete array type from the dynamic array. - /// This function does not perform any conversion, it only checks if - /// the underlying data type is exactly the same as the target type. - /// To perform conversion, use the `.try_into()` method. - pub fn [](self) -> Result> { - match self { - DynArray::$variant(x) => Ok(x.into_dimensionality()?), - v => bail!("Cannot convert {} to {}", v.data_type(), stringify!($scalar_ty)), - } - } - pub fn [](&self) -> Result<&ArrayD<$scalar_ty>> { match self { DynArray::$variant(x) => Ok(x), @@ -170,248 +159,33 @@ impl DynArray { ); } -macro_rules! impl_to_dynarray{ - ($($scalar_type:ty, $ident:ident),*) => { +macro_rules! impl_dynarray_traits{ + ($($scalar_ty:ty, $ident:ident),*) => { $( - impl From> for DynArray { - fn from(data: Array<$scalar_type, D>) -> Self { + impl From> for DynArray { + fn from(data: Array<$scalar_ty, D>) -> Self { DynArray::$ident(data.into_dyn()) } } + + impl TryFrom for Array<$scalar_ty, D> { + type Error = anyhow::Error; + fn try_from(arr: DynArray) -> Result { + match arr { + DynArray::$ident(x) => Ok(x.into_dimensionality::()?), + v => bail!("Cannot convert {} to {}", v.data_type(), stringify!($scalar_ty)), + } + } + } )* }; } -impl_to_dynarray!( +impl_dynarray_traits!( i8, I8, i16, I16, i32, I32, i64, I64, u8, U8, u16, U16, u32, U32, u64, U64, f32, F32, f64, F64, bool, Bool, String, String ); -impl TryFrom for Array { - type Error = anyhow::Error; - fn try_from(v: DynArray) -> Result { - match v { - DynArray::I8(data) => Ok(data.into_dimensionality()?), - DynArray::I16(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::I32(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::I64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::U8(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::U16(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::U32(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::U64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::Bool(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - _ => bail!("Cannot convert to i8 Array"), - } - } -} - -impl TryFrom for Array { - type Error = anyhow::Error; - fn try_from(v: DynArray) -> Result { - match v { - DynArray::I16(data) => Ok(data.into_dimensionality()?), - DynArray::I8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::I32(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::I64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::U8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::U16(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::U32(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::U64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::Bool(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - _ => bail!("Cannot convert to i16 Array"), - } - } -} - -impl TryFrom for Array { - type Error = anyhow::Error; - fn try_from(v: DynArray) -> Result { - match v { - DynArray::I32(data) => Ok(data.into_dimensionality()?), - DynArray::I8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::I16(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::I64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::U8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::U16(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::U32(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::U64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::Bool(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - _ => bail!("Cannot convert to i32 Array"), - } - } -} - -impl TryFrom for Array { - type Error = anyhow::Error; - fn try_from(v: DynArray) -> Result { - match v { - DynArray::I64(data) => Ok(data.into_dimensionality()?), - DynArray::I8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::I16(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::I32(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::U8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::U16(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::U32(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::U64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::Bool(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - _ => bail!("Cannot convert to i64 Array"), - } - } -} - -impl TryFrom for Array { - type Error = anyhow::Error; - fn try_from(v: DynArray) -> Result { - match v { - DynArray::U8(data) => Ok(data.into_dimensionality()?), - DynArray::I8(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::I16(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::I32(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::I64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::U16(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::U32(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::U64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::Bool(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - _ => bail!("Cannot convert to u8 Array"), - } - } -} - -impl TryFrom for Array { - type Error = anyhow::Error; - fn try_from(v: DynArray) -> Result { - match v { - DynArray::U16(data) => Ok(data.into_dimensionality()?), - DynArray::I8(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::I16(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::I32(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::I64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::U8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::U32(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::U64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::Bool(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - _ => bail!("Cannot convert to u16 Array"), - } - } -} - -impl TryFrom for Array { - type Error = anyhow::Error; - fn try_from(v: DynArray) -> Result { - match v { - DynArray::U32(data) => Ok(data.into_dimensionality()?), - DynArray::I8(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::I16(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::I32(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::I64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::U8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::U16(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::U64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::Bool(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - _ => bail!("Cannot convert to u32 Array"), - } - } -} - -impl TryFrom for Array { - type Error = anyhow::Error; - fn try_from(v: DynArray) -> Result { - match v { - DynArray::U64(data) => Ok(data.into_dimensionality()?), - DynArray::I8(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::I16(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::I32(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::I64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::U8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::U16(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::U32(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::Bool(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - _ => bail!("Cannot convert to u64 Array"), - } - } -} - -impl TryFrom for Array { - type Error = anyhow::Error; - fn try_from(v: DynArray) -> Result { - match v { - DynArray::I8(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::I16(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::I32(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::I64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::U8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::U16(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::U32(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::U64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), - DynArray::Bool(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - _ => bail!("Cannot convert to usize Array"), - } - } -} - -impl TryFrom for Array { - type Error = anyhow::Error; - fn try_from(v: DynArray) -> Result { - match v { - DynArray::F32(data) => Ok(data.into_dimensionality()?), - DynArray::I8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::I16(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::U8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::U16(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::Bool(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - _ => bail!("Cannot convert to f32 Array"), - } - } -} - -impl TryFrom for Array { - type Error = anyhow::Error; - fn try_from(v: DynArray) -> Result { - match v { - DynArray::F64(data) => Ok(data.into_dimensionality()?), - DynArray::I8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::I16(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::I32(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::U8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::U16(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::U32(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::F32(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - DynArray::Bool(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), - _ => bail!("Cannot convert to f64 Array"), - } - } -} - -impl TryFrom for Array { - type Error = anyhow::Error; - fn try_from(v: DynArray) -> Result { - match v { - DynArray::Bool(data) => Ok(data.into_dimensionality()?), - _ => bail!("Cannot convert to bool Array"), - } - } -} - -impl TryFrom for Array { - type Error = anyhow::Error; - fn try_from(v: DynArray) -> Result { - match v { - DynArray::String(data) => Ok(data.into_dimensionality()?), - DynArray::I8(data) => Ok(data.mapv(|x| x.to_string()).into_dimensionality()?), - DynArray::I16(data) => Ok(data.mapv(|x| x.to_string()).into_dimensionality()?), - DynArray::I32(data) => Ok(data.mapv(|x| x.to_string()).into_dimensionality()?), - DynArray::I64(data) => Ok(data.mapv(|x| x.to_string()).into_dimensionality()?), - DynArray::U8(data) => Ok(data.mapv(|x| x.to_string()).into_dimensionality()?), - DynArray::U16(data) => Ok(data.mapv(|x| x.to_string()).into_dimensionality()?), - DynArray::U32(data) => Ok(data.mapv(|x| x.to_string()).into_dimensionality()?), - DynArray::U64(data) => Ok(data.mapv(|x| x.to_string()).into_dimensionality()?), - DynArray::F32(data) => Ok(data.mapv(|x| x.to_string()).into_dimensionality()?), - DynArray::F64(data) => Ok(data.mapv(|x| x.to_string()).into_dimensionality()?), - DynArray::Bool(data) => Ok(data.mapv(|x| x.to_string()).into_dimensionality()?), - } - } -} - impl Into for DynArray { fn into(self) -> Series { match self { @@ -486,40 +260,40 @@ impl ArrayOp for DynArray { let mut iter = iter.peekable(); match iter.peek().unwrap() { DynArray::U8(_) => { - ArrayD::vstack(iter.map(|x| x.into_u8().unwrap())).map(|x| x.into()) + ArrayD::::vstack(iter.map(|x| x.try_into().unwrap())).map(|x| x.into()) } DynArray::U16(_) => { - ArrayD::vstack(iter.map(|x| x.into_u16().unwrap())).map(|x| x.into()) + ArrayD::::vstack(iter.map(|x| x.try_into().unwrap())).map(|x| x.into()) } DynArray::U32(_) => { - ArrayD::vstack(iter.map(|x| x.into_u32().unwrap())).map(|x| x.into()) + ArrayD::::vstack(iter.map(|x| x.try_into().unwrap())).map(|x| x.into()) } DynArray::U64(_) => { - ArrayD::vstack(iter.map(|x| x.into_u64().unwrap())).map(|x| x.into()) + ArrayD::::vstack(iter.map(|x| x.try_into().unwrap())).map(|x| x.into()) } DynArray::I8(_) => { - ArrayD::vstack(iter.map(|x| x.into_i8().unwrap())).map(|x| x.into()) + ArrayD::::vstack(iter.map(|x| x.try_into().unwrap())).map(|x| x.into()) } DynArray::I16(_) => { - ArrayD::vstack(iter.map(|x| x.into_i16().unwrap())).map(|x| x.into()) + ArrayD::::vstack(iter.map(|x| x.try_into().unwrap())).map(|x| x.into()) } DynArray::I32(_) => { - ArrayD::vstack(iter.map(|x| x.into_i32().unwrap())).map(|x| x.into()) + ArrayD::::vstack(iter.map(|x| x.try_into().unwrap())).map(|x| x.into()) } DynArray::I64(_) => { - ArrayD::vstack(iter.map(|x| x.into_i64().unwrap())).map(|x| x.into()) + ArrayD::::vstack(iter.map(|x| x.try_into().unwrap())).map(|x| x.into()) } DynArray::F32(_) => { - ArrayD::vstack(iter.map(|x| x.into_f32().unwrap())).map(|x| x.into()) + ArrayD::::vstack(iter.map(|x| x.try_into().unwrap())).map(|x| x.into()) } DynArray::F64(_) => { - ArrayD::vstack(iter.map(|x| x.into_f64().unwrap())).map(|x| x.into()) + ArrayD::::vstack(iter.map(|x| x.try_into().unwrap())).map(|x| x.into()) } DynArray::Bool(_) => { - ArrayD::vstack(iter.map(|x| x.into_bool().unwrap())).map(|x| x.into()) + ArrayD::::vstack(iter.map(|x| x.try_into().unwrap())).map(|x| x.into()) } DynArray::String(_) => { - ArrayD::vstack(iter.map(|x| x.into_string().unwrap())).map(|x| x.into()) + ArrayD::::vstack(iter.map(|x| x.try_into().unwrap())).map(|x| x.into()) } } } @@ -635,4 +409,222 @@ impl_dyn_cowarray_convert!(u64, U64); impl_dyn_cowarray_convert!(f32, F32); impl_dyn_cowarray_convert!(f64, F64); impl_dyn_cowarray_convert!(bool, Bool); -impl_dyn_cowarray_convert!(String, String); \ No newline at end of file +impl_dyn_cowarray_convert!(String, String); + +/// `ArrayConvert` trait for converting dynamic arrays to concrete arrays. +/// The `try_convert` method performs the conversion and returns the result. +pub trait ArrayConvert { + fn try_convert(self) -> Result; +} + +impl ArrayConvert> for DynArray { + fn try_convert(self) -> Result> { + match self { + DynArray::I8(data) => Ok(data.into_dimensionality()?), + DynArray::I16(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::I32(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::I64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::U8(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::U16(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::U32(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::U64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::Bool(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + _ => bail!("Cannot convert to i8 Array"), + } + } +} + +impl ArrayConvert> for DynArray { + fn try_convert(self) -> Result> { + match self { + DynArray::I16(data) => Ok(data.into_dimensionality()?), + DynArray::I8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::I32(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::I64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::U8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::U16(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::U32(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::U64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::Bool(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + _ => bail!("Cannot convert to i16 Array"), + } + } +} + +impl ArrayConvert> for DynArray { + fn try_convert(self) -> Result> { + match self { + DynArray::I32(data) => Ok(data.into_dimensionality()?), + DynArray::I8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::I16(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::I64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::U8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::U16(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::U32(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::U64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::Bool(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + _ => bail!("Cannot convert to i32 Array"), + } + } +} + +impl ArrayConvert> for DynArray { + fn try_convert(self) -> Result> { + match self { + DynArray::I64(data) => Ok(data.into_dimensionality()?), + DynArray::I8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::I16(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::I32(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::U8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::U16(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::U32(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::U64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::Bool(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + _ => bail!("Cannot convert to i64 Array"), + } + } +} + +impl ArrayConvert> for DynArray { + fn try_convert(self) -> Result> { + match self { + DynArray::U8(data) => Ok(data.into_dimensionality()?), + DynArray::I8(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::I16(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::I32(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::I64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::U16(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::U32(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::U64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::Bool(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + _ => bail!("Cannot convert to u8 Array"), + } + } +} + +impl ArrayConvert> for DynArray { + fn try_convert(self) -> Result> { + match self { + DynArray::U16(data) => Ok(data.into_dimensionality()?), + DynArray::I8(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::I16(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::I32(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::I64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::U8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::U32(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::U64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::Bool(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + _ => bail!("Cannot convert to u16 Array"), + } + } +} + +impl ArrayConvert> for DynArray { + fn try_convert(self) -> Result> { + match self { + DynArray::U32(data) => Ok(data.into_dimensionality()?), + DynArray::I8(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::I16(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::I32(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::I64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::U8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::U16(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::U64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::Bool(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + _ => bail!("Cannot convert to u32 Array"), + } + } +} + +impl ArrayConvert> for DynArray { + fn try_convert(self) -> Result> { + match self { + DynArray::U64(data) => Ok(data.into_dimensionality()?), + DynArray::I8(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::I16(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::I32(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::I64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::U8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::U16(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::U32(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::Bool(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + _ => bail!("Cannot convert to u64 Array"), + } + } +} + +impl ArrayConvert> for DynArray { + fn try_convert(self) -> Result> { + match self { + DynArray::I8(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::I16(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::I32(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::I64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::U8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::U16(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::U32(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::U64(data) => Ok(data.mapv(|x| x.try_into().unwrap()).into_dimensionality()?), + DynArray::Bool(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + _ => bail!("Cannot convert to usize Array"), + } + } +} + +impl ArrayConvert> for DynArray { + fn try_convert(self) -> Result> { + match self { + DynArray::F32(data) => Ok(data.into_dimensionality()?), + DynArray::I8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::I16(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::U8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::U16(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::Bool(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + _ => bail!("Cannot convert to f32 Array"), + } + } +} + +impl ArrayConvert> for DynArray { + fn try_convert(self) -> Result> { + match self { + DynArray::F64(data) => Ok(data.into_dimensionality()?), + DynArray::I8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::I16(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::I32(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::U8(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::U16(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::U32(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::F32(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + DynArray::Bool(data) => Ok(data.mapv(|x| x.into()).into_dimensionality()?), + _ => bail!("Cannot convert to f64 Array"), + } + } +} + +impl ArrayConvert> for DynArray { + fn try_convert(self) -> Result> { + match self { + DynArray::Bool(data) => Ok(data.into_dimensionality()?), + _ => bail!("Cannot convert to bool Array"), + } + } +} + +impl ArrayConvert> for DynArray { + fn try_convert(self) -> Result> { + match self { + DynArray::String(data) => Ok(data.into_dimensionality()?), + DynArray::I8(data) => Ok(data.mapv(|x| x.to_string()).into_dimensionality()?), + DynArray::I16(data) => Ok(data.mapv(|x| x.to_string()).into_dimensionality()?), + DynArray::I32(data) => Ok(data.mapv(|x| x.to_string()).into_dimensionality()?), + DynArray::I64(data) => Ok(data.mapv(|x| x.to_string()).into_dimensionality()?), + DynArray::U8(data) => Ok(data.mapv(|x| x.to_string()).into_dimensionality()?), + DynArray::U16(data) => Ok(data.mapv(|x| x.to_string()).into_dimensionality()?), + DynArray::U32(data) => Ok(data.mapv(|x| x.to_string()).into_dimensionality()?), + DynArray::U64(data) => Ok(data.mapv(|x| x.to_string()).into_dimensionality()?), + DynArray::F32(data) => Ok(data.mapv(|x| x.to_string()).into_dimensionality()?), + DynArray::F64(data) => Ok(data.mapv(|x| x.to_string()).into_dimensionality()?), + DynArray::Bool(data) => Ok(data.mapv(|x| x.to_string()).into_dimensionality()?), + } + } +} \ No newline at end of file diff --git a/anndata/src/data/array/sparse/dynamic.rs b/anndata/src/data/array/sparse/dynamic.rs index 581169b..09dd7d4 100644 --- a/anndata/src/data/array/sparse/dynamic.rs +++ b/anndata/src/data/array/sparse/dynamic.rs @@ -1,11 +1,12 @@ use crate::backend::*; +use crate::data::ArrayConvert; use crate::data::{ array::DynScalar, data_traits::*, slice::{SelectInfoElem, Shape}, }; -use anyhow::{bail, Context, Result}; +use anyhow::{bail, Result}; use nalgebra_sparse::csc::CscMatrix; use nalgebra_sparse::csr::CsrMatrix; use ndarray::Ix1; @@ -27,91 +28,35 @@ pub enum DynCsrMatrix { String(CsrMatrix), } -macro_rules! impl_into_dyn_csr { - ($from_type:ty, $to_type:ident) => { - impl From> for DynCsrMatrix { - fn from(data: CsrMatrix<$from_type>) -> Self { - DynCsrMatrix::$to_type(data) +macro_rules! impl_dyncsr_traits { + ($($scalar_ty:ty, $variant:ident),*) => { + $( + impl From> for DynCsrMatrix { + fn from(data: CsrMatrix<$scalar_ty>) -> Self { + DynCsrMatrix::$variant(data) + } } - } - impl TryFrom for CsrMatrix<$from_type> { - type Error = anyhow::Error; - fn try_from(data: DynCsrMatrix) -> Result { - match data { - DynCsrMatrix::$to_type(data) => Ok(data), - _ => bail!( - "Cannot convert {:?} to {} CsrMatrix", - data.data_type(), - stringify!($from_type) - ), + impl TryFrom for CsrMatrix<$scalar_ty> { + type Error = anyhow::Error; + fn try_from(data: DynCsrMatrix) -> Result { + match data { + DynCsrMatrix::$variant(data) => Ok(data), + _ => bail!( + "Cannot convert {} to {} CsrMatrix", + data.data_type(), + stringify!($scalar_ty) + ), + } } } - } + )* }; } -impl From> for DynCsrMatrix { - fn from(data: CsrMatrix) -> Self { - DynCsrMatrix::U32(data) - } -} - -impl TryFrom for CsrMatrix { - type Error = anyhow::Error; - fn try_from(data: DynCsrMatrix) -> Result { - match data { - DynCsrMatrix::U32(data) => Ok(data), - DynCsrMatrix::I8(data) => Ok(cast_csr(data)?), - DynCsrMatrix::I16(data) => Ok(cast_csr(data)?), - DynCsrMatrix::I32(data) => Ok(cast_csr(data)?), - DynCsrMatrix::I64(data) => Ok(from_i64_csr(data)?), - DynCsrMatrix::U8(data) => Ok(cast_csr(data)?), - DynCsrMatrix::U16(data) => Ok(cast_csr(data)?), - DynCsrMatrix::U64(data) => Ok(cast_csr(data)?), - DynCsrMatrix::F32(_) => bail!("Cannot convert f32 to u32"), - DynCsrMatrix::F64(_) => bail!("Cannot convert f64 to u32"), - DynCsrMatrix::Bool(_) => bail!("Cannot convert bool to f64"), - DynCsrMatrix::String(_) => bail!("Cannot convert string to f64"), - } - } -} - -impl From> for DynCsrMatrix { - fn from(data: CsrMatrix) -> Self { - DynCsrMatrix::F64(data) - } -} - -impl TryFrom for CsrMatrix { - type Error = anyhow::Error; - fn try_from(data: DynCsrMatrix) -> Result { - match data { - DynCsrMatrix::F64(data) => Ok(data), - DynCsrMatrix::I8(data) => Ok(cast_csr(data)?), - DynCsrMatrix::I16(data) => Ok(cast_csr(data)?), - DynCsrMatrix::I32(data) => Ok(cast_csr(data)?), - DynCsrMatrix::I64(data) => Ok(from_i64_csr(data)?), - DynCsrMatrix::U8(data) => Ok(cast_csr(data)?), - DynCsrMatrix::U16(data) => Ok(cast_csr(data)?), - DynCsrMatrix::U32(data) => Ok(cast_csr(data)?), - DynCsrMatrix::U64(_) => bail!("Cannot convert u64 to f64"), - DynCsrMatrix::F32(data) => Ok(cast_csr(data)?), - DynCsrMatrix::Bool(_) => bail!("Cannot convert bool to f64"), - DynCsrMatrix::String(_) => bail!("Cannot convert string to f64"), - } - } -} - -impl_into_dyn_csr!(i8, I8); -impl_into_dyn_csr!(i16, I16); -impl_into_dyn_csr!(i32, I32); -impl_into_dyn_csr!(i64, I64); -impl_into_dyn_csr!(u8, U8); -impl_into_dyn_csr!(u16, U16); -impl_into_dyn_csr!(u64, U64); -impl_into_dyn_csr!(f32, F32); -impl_into_dyn_csr!(bool, Bool); -impl_into_dyn_csr!(String, String); +impl_dyncsr_traits!( + i8, I8, i16, I16, i32, I32, i64, I64, u8, U8, u16, U16, u32, U32, u64, U64, f32, F32, f64, F64, + bool, Bool, String, String +); impl WriteData for DynCsrMatrix { fn data_type(&self) -> DataType { @@ -136,7 +81,7 @@ impl ReadData for DynCsrMatrix { }; } crate::macros::dyn_match!(group.open_dataset("data")?.dtype()?, ScalarType, fun) - }, + } _ => bail!("cannot read csr matrix from non-group container"), } } @@ -237,7 +182,6 @@ impl ReadArrayData for DynCsrMatrix { } } - #[derive(Debug, Clone, PartialEq)] pub enum DynCscMatrix { I8(CscMatrix), @@ -254,91 +198,35 @@ pub enum DynCscMatrix { String(CscMatrix), } -macro_rules! impl_into_dyn_csc { - ($from_type:ty, $to_type:ident) => { - impl From> for DynCscMatrix { - fn from(data: CscMatrix<$from_type>) -> Self { - DynCscMatrix::$to_type(data) +macro_rules! impl_dyncsc_traits { + ($($from_type:ty, $to_type:ident),*) => { + $( + impl From> for DynCscMatrix { + fn from(data: CscMatrix<$from_type>) -> Self { + DynCscMatrix::$to_type(data) + } } - } - impl TryFrom for CscMatrix<$from_type> { - type Error = anyhow::Error; - fn try_from(data: DynCscMatrix) -> Result { - match data { - DynCscMatrix::$to_type(data) => Ok(data), - _ => bail!( - "Cannot convert {:?} to {} CscMatrix", - data.data_type(), - stringify!($from_type) - ), + impl TryFrom for CscMatrix<$from_type> { + type Error = anyhow::Error; + fn try_from(data: DynCscMatrix) -> Result { + match data { + DynCscMatrix::$to_type(data) => Ok(data), + _ => bail!( + "Cannot convert {:?} to {} CscMatrix", + data.data_type(), + stringify!($from_type) + ), + } } } - } + )* }; } -impl From> for DynCscMatrix { - fn from(data: CscMatrix) -> Self { - DynCscMatrix::U32(data) - } -} - -impl TryFrom for CscMatrix { - type Error = anyhow::Error; - fn try_from(data: DynCscMatrix) -> Result { - match data { - DynCscMatrix::U32(data) => Ok(data), - DynCscMatrix::I8(data) => Ok(cast_csc(data)?), - DynCscMatrix::I16(data) => Ok(cast_csc(data)?), - DynCscMatrix::I32(data) => Ok(cast_csc(data)?), - DynCscMatrix::I64(data) => Ok(from_i64_csc(data)?), - DynCscMatrix::U8(data) => Ok(cast_csc(data)?), - DynCscMatrix::U16(data) => Ok(cast_csc(data)?), - DynCscMatrix::U64(data) => Ok(cast_csc(data)?), - DynCscMatrix::F32(_) => bail!("Cannot convert f32 to u32"), - DynCscMatrix::F64(_) => bail!("Cannot convert f64 to u32"), - DynCscMatrix::Bool(_) => bail!("Cannot convert bool to f64"), - DynCscMatrix::String(_) => bail!("Cannot convert string to f64"), - } - } -} - -impl From> for DynCscMatrix { - fn from(data: CscMatrix) -> Self { - DynCscMatrix::F64(data) - } -} - -impl TryFrom for CscMatrix { - type Error = anyhow::Error; - fn try_from(data: DynCscMatrix) -> Result { - match data { - DynCscMatrix::F64(data) => Ok(data), - DynCscMatrix::I8(data) => Ok(cast_csc(data)?), - DynCscMatrix::I16(data) => Ok(cast_csc(data)?), - DynCscMatrix::I32(data) => Ok(cast_csc(data)?), - DynCscMatrix::I64(data) => Ok(from_i64_csc(data)?), - DynCscMatrix::U8(data) => Ok(cast_csc(data)?), - DynCscMatrix::U16(data) => Ok(cast_csc(data)?), - DynCscMatrix::U32(data) => Ok(cast_csc(data)?), - DynCscMatrix::U64(_) => bail!("Cannot convert u64 to f64"), - DynCscMatrix::F32(data) => Ok(cast_csc(data)?), - DynCscMatrix::Bool(_) => bail!("Cannot convert bool to f64"), - DynCscMatrix::String(_) => bail!("Cannot convert string to f64"), - } - } -} - -impl_into_dyn_csc!(i8, I8); -impl_into_dyn_csc!(i16, I16); -impl_into_dyn_csc!(i32, I32); -impl_into_dyn_csc!(i64, I64); -impl_into_dyn_csc!(u8, U8); -impl_into_dyn_csc!(u16, U16); -impl_into_dyn_csc!(u64, U64); -impl_into_dyn_csc!(f32, F32); -impl_into_dyn_csc!(bool, Bool); -impl_into_dyn_csc!(String, String); +impl_dyncsc_traits!( + i8, I8, i16, I16, i32, I32, i64, I64, u8, U8, u16, U16, u32, U32, u64, U64, f32, F32, f64, F64, + bool, Bool, String, String +); impl WriteData for DynCscMatrix { fn data_type(&self) -> DataType { @@ -364,7 +252,7 @@ impl ReadData for DynCscMatrix { }; } crate::macros::dyn_match!(group.open_dataset("data")?.dtype()?, ScalarType, fun) - }, + } _ => bail!("cannot read csc matrix from non-group container"), } } @@ -401,7 +289,11 @@ impl ArrayOp for DynCscMatrix { DynCscMatrix::$variant(CscMatrix::vstack(iter.map(|x| x.try_into().unwrap()))?) }; } - Ok(crate::macros::dyn_map!(iter.peek().unwrap(), DynCscMatrix, fun)) + Ok(crate::macros::dyn_map!( + iter.peek().unwrap(), + DynCscMatrix, + fun + )) } } @@ -435,65 +327,101 @@ impl ReadArrayData for DynCscMatrix { } //////////////////////////////////////////////////////////////////////////////// -// Helper functions +// ArrayConvert implementations //////////////////////////////////////////////////////////////////////////////// -fn cast_csr(csr: CsrMatrix) -> Result> +macro_rules! impl_arrayconvert { + ($($ty:ident, $fun:expr),*) => { + $(paste::paste! { + + impl ArrayConvert<$ty> for [] { + fn try_convert(self) -> Result<$ty> { + match self { + []::U32(data) => Ok(data), + []::I8(data) => $fun(data, |x| Ok(x.try_into()?)), + []::I16(data) => $fun(data, |x| Ok(x.try_into()?)), + []::I32(data) => $fun(data, |x| Ok(x.try_into()?)), + []::I64(data) => $fun(data, |x| Ok(x.try_into()?)), + []::U8(data) => $fun(data, |x| Ok(x.into())), + []::U16(data) => $fun(data, |x| Ok(x.into())), + []::U64(data) => $fun(data, |x| Ok(x.try_into()?)), + []::Bool(data) => $fun(data, |x| Ok(x.into())), + v => bail!("Cannot convert {} to {}", v.data_type(), stringify!($ty)), + } + } + } + + impl ArrayConvert<$ty> for [] { + fn try_convert(self) -> Result<$ty> { + match self { + []::F32(data) => Ok(data), + []::I8(data) => $fun(data, |x| Ok(x.into())), + []::I16(data) => $fun(data, |x| Ok(x.into())), + []::I32(data) => $fun(data, |x| Ok(f32::from_i32(x).unwrap())), + []::I64(data) => $fun(data, |x| Ok(f32::from_i64(x).unwrap())), + []::U8(data) => $fun(data, |x| Ok(x.into())), + []::U16(data) => $fun(data, |x| Ok(x.into())), + []::U32(data) => $fun(data, |x| Ok(f32::from_u32(x).unwrap())), + []::U64(data) => $fun(data, |x| Ok(f32::from_u64(x).unwrap())), + []::F64(data) => $fun(data, |x| Ok(f32::from_f64(x).unwrap())), + []::Bool(data) => $fun(data, |x| Ok(x.into())), + v => bail!("Cannot convert {} to {}", v.data_type(), stringify!($ty)), + } + } + } + + impl ArrayConvert<$ty> for [] { + fn try_convert(self) -> Result<$ty> { + match self { + []::F64(data) => Ok(data), + []::I8(data) => $fun(data, |x| Ok(x.into())), + []::I16(data) => $fun(data, |x| Ok(x.into())), + []::I32(data) => $fun(data, |x| Ok(x.into())), + []::I64(data) => $fun(data, |x| Ok(f64::from_i64(x).unwrap())), + []::U8(data) => $fun(data, |x| Ok(x.into())), + []::U16(data) => $fun(data, |x| Ok(x.into())), + []::U32(data) => $fun(data, |x| Ok(x.into())), + []::U64(data) => $fun(data, |x| Ok(f64::from_u64(x).unwrap())), + []::F32(data) => $fun(data, |x| Ok(x.into())), + []::Bool(data) => $fun(data, |x| Ok(x.into())), + v => bail!("Cannot convert {} to {}", v.data_type(), stringify!($ty)), + } + } + } + + })* + }; +} + +impl_arrayconvert!(CsrMatrix, convert_csr_with, CscMatrix, convert_csc_with); + +fn convert_csr_with(csr: CsrMatrix, f: F) -> Result> where - T: TryInto, - >::Error: std::error::Error + Sync + Send + 'static, + F: Fn(T) -> Result, { let (pattern, values) = csr.into_pattern_and_values(); let out = CsrMatrix::try_from_pattern_and_values( pattern, values .into_iter() - .map(|x| x.try_into()) + .map(|x| f(x)) .collect::>()?, - ) - .unwrap(); - Ok(out) -} - -fn from_i64_csr(csr: CsrMatrix) -> Result> { - let (pattern, values) = csr.into_pattern_and_values(); - let out = CsrMatrix::try_from_pattern_and_values( - pattern, - values - .into_iter() - .map(|x| U::from_i64(x).context("cannot convert from i64")) - .collect::>()?, - ) - .unwrap(); + ).unwrap(); Ok(out) } -fn cast_csc(csc: CscMatrix) -> Result> +fn convert_csc_with(csc: CscMatrix, f: F) -> Result> where - T: TryInto, - >::Error: std::error::Error + Sync + Send + 'static, + F: Fn(T) -> Result, { let (pattern, values) = csc.into_pattern_and_values(); let out = CscMatrix::try_from_pattern_and_values( pattern, values .into_iter() - .map(|x| x.try_into()) + .map(|x| f(x)) .collect::>()?, ) .unwrap(); Ok(out) -} - -fn from_i64_csc(csc: CscMatrix) -> Result> { - let (pattern, values) = csc.into_pattern_and_values(); - let out = CscMatrix::try_from_pattern_and_values( - pattern, - values - .into_iter() - .map(|x| U::from_i64(x).context("cannot convert from i64")) - .collect::>()?, - ) - .unwrap(); - Ok(out) -} +} \ No newline at end of file diff --git a/anndata/src/data/array/sparse/noncanonical.rs b/anndata/src/data/array/sparse/noncanonical.rs index 614b579..75a822a 100644 --- a/anndata/src/data/array/sparse/noncanonical.rs +++ b/anndata/src/data/array/sparse/noncanonical.rs @@ -1,8 +1,8 @@ use crate::backend::*; use crate::data::{ array::utils::{cs_major_index, cs_major_minor_index, cs_major_slice}, - data_traits::*, array::DynScalar, + data_traits::*, slice::{SelectInfoElem, Shape}, SelectInfoBounds, SelectInfoElemBounds, }; @@ -86,143 +86,71 @@ impl DynCsrNonCanonical { } } -macro_rules! impl_into_dyn_csr { - ($from_type:ty, $to_type:ident) => { - impl From> for DynCsrNonCanonical { - fn from(data: CsrNonCanonical<$from_type>) -> Self { - DynCsrNonCanonical::$to_type(data) +macro_rules! impl_noncanonicalcsr_traits { + ($($from_type:ty, $to_type:ident),*) => { + $( + impl From> for DynCsrNonCanonical { + fn from(data: CsrNonCanonical<$from_type>) -> Self { + DynCsrNonCanonical::$to_type(data) + } } - } - impl TryFrom for CsrNonCanonical<$from_type> { - type Error = anyhow::Error; - fn try_from(data: DynCsrNonCanonical) -> Result { - match data { - DynCsrNonCanonical::$to_type(data) => Ok(data), - _ => bail!( - "Cannot convert {:?} to {} CsrNonCanonical", - data.data_type(), - stringify!($from_type) - ), + impl TryFrom for CsrNonCanonical<$from_type> { + type Error = anyhow::Error; + fn try_from(data: DynCsrNonCanonical) -> Result { + match data { + DynCsrNonCanonical::$to_type(data) => Ok(data), + _ => bail!( + "Cannot convert {:?} to {} CsrNonCanonical", + data.data_type(), + stringify!($from_type) + ), + } } } - } + )* }; } -impl_into_dyn_csr!(i8, I8); -impl_into_dyn_csr!(i16, I16); -impl_into_dyn_csr!(i32, I32); -impl_into_dyn_csr!(i64, I64); -impl_into_dyn_csr!(u8, U8); -impl_into_dyn_csr!(u16, U16); -impl_into_dyn_csr!(u32, U32); -impl_into_dyn_csr!(u64, U64); -impl_into_dyn_csr!(f32, F32); -impl_into_dyn_csr!(f64, F64); -impl_into_dyn_csr!(bool, Bool); -impl_into_dyn_csr!(String, String); - -macro_rules! impl_dyn_csr_matrix { - ($self:expr, $fun:ident) => { - match $self { - DynCsrNonCanonical::I8(data) => $fun!(data), - DynCsrNonCanonical::I16(data) => $fun!(data), - DynCsrNonCanonical::I32(data) => $fun!(data), - DynCsrNonCanonical::I64(data) => $fun!(data), - DynCsrNonCanonical::U8(data) => $fun!(data), - DynCsrNonCanonical::U16(data) => $fun!(data), - DynCsrNonCanonical::U32(data) => $fun!(data), - DynCsrNonCanonical::U64(data) => $fun!(data), - DynCsrNonCanonical::F32(data) => $fun!(data), - DynCsrNonCanonical::F64(data) => $fun!(data), - DynCsrNonCanonical::Bool(data) => $fun!(data), - DynCsrNonCanonical::String(data) => $fun!(data), - } - }; -} +impl_noncanonicalcsr_traits!( + i8, I8, i16, I16, i32, I32, i64, I64, u8, U8, u16, U16, u32, U32, u64, U64, f32, F32, f64, F64, + bool, Bool, String, String +); impl From for DynCsrNonCanonical { fn from(value: DynCsrMatrix) -> Self { - match value { - DynCsrMatrix::I8(data) => DynCsrNonCanonical::I8(data.into()), - DynCsrMatrix::I16(data) => DynCsrNonCanonical::I16(data.into()), - DynCsrMatrix::I32(data) => DynCsrNonCanonical::I32(data.into()), - DynCsrMatrix::I64(data) => DynCsrNonCanonical::I64(data.into()), - DynCsrMatrix::U8(data) => DynCsrNonCanonical::U8(data.into()), - DynCsrMatrix::U16(data) => DynCsrNonCanonical::U16(data.into()), - DynCsrMatrix::U32(data) => DynCsrNonCanonical::U32(data.into()), - DynCsrMatrix::U64(data) => DynCsrNonCanonical::U64(data.into()), - DynCsrMatrix::F32(data) => DynCsrNonCanonical::F32(data.into()), - DynCsrMatrix::F64(data) => DynCsrNonCanonical::F64(data.into()), - DynCsrMatrix::Bool(data) => DynCsrNonCanonical::Bool(data.into()), - DynCsrMatrix::String(data) => DynCsrNonCanonical::String(data.into()), + macro_rules! fun { + ($variant:ident, $data:expr) => { + DynCsrNonCanonical::$variant($data.into()) + }; } + crate::macros::dyn_map!(value, DynCsrMatrix, fun) } } impl WriteData for DynCsrNonCanonical { fn data_type(&self) -> DataType { - macro_rules! data_type { - ($data:expr) => { - $data.data_type() - }; - } - impl_dyn_csr_matrix!(self, data_type) + crate::macros::dyn_map_fun!(self, DynCsrNonCanonical, data_type) } + fn write>( &self, location: &G, name: &str, ) -> Result> { - macro_rules! write_data { - ($data:expr) => { - $data.write(location, name) - }; - } - impl_dyn_csr_matrix!(self, write_data) + crate::macros::dyn_map_fun!(self, DynCsrNonCanonical, write, location, name) } } impl ReadData for DynCsrNonCanonical { fn read(container: &DataContainer) -> Result { match container { - DataContainer::Group(group) => match group.open_dataset("data")?.dtype()? { - ScalarType::I8 => { - CsrNonCanonical::::read(container).map(DynCsrNonCanonical::I8) - } - ScalarType::I16 => { - CsrNonCanonical::::read(container).map(DynCsrNonCanonical::I16) - } - ScalarType::I32 => { - CsrNonCanonical::::read(container).map(DynCsrNonCanonical::I32) - } - ScalarType::I64 => { - CsrNonCanonical::::read(container).map(DynCsrNonCanonical::I64) - } - ScalarType::U8 => { - CsrNonCanonical::::read(container).map(DynCsrNonCanonical::U8) - } - ScalarType::U16 => { - CsrNonCanonical::::read(container).map(DynCsrNonCanonical::U16) - } - ScalarType::U32 => { - CsrNonCanonical::::read(container).map(DynCsrNonCanonical::U32) - } - ScalarType::U64 => { - CsrNonCanonical::::read(container).map(DynCsrNonCanonical::U64) - } - ScalarType::F32 => { - CsrNonCanonical::::read(container).map(DynCsrNonCanonical::F32) - } - ScalarType::F64 => { - CsrNonCanonical::::read(container).map(DynCsrNonCanonical::F64) - } - ScalarType::Bool => { - CsrNonCanonical::::read(container).map(DynCsrNonCanonical::Bool) - } - ScalarType::String => { - CsrNonCanonical::::read(container).map(DynCsrNonCanonical::String) + DataContainer::Group(group) => { + macro_rules! fun { + ($variant:ident) => { + CsrNonCanonical::read(container).map(DynCsrNonCanonical::$variant) + }; } + crate::macros::dyn_match!(group.open_dataset("data")?.dtype()?, ScalarType, fun) }, _ => bail!("cannot read csr matrix from non-group container"), } @@ -231,35 +159,25 @@ impl ReadData for DynCsrNonCanonical { impl HasShape for DynCsrNonCanonical { fn shape(&self) -> Shape { - macro_rules! shape { - ($data:expr) => { - $data.shape() - }; - } - impl_dyn_csr_matrix!(self, shape) + crate::macros::dyn_map_fun!(self, DynCsrNonCanonical, shape) } } impl ArrayOp for DynCsrNonCanonical { fn get(&self, index: &[usize]) -> Option { - macro_rules! get { - ($data:expr) => { - $data.get(index) - }; - } - impl_dyn_csr_matrix!(self, get) + crate::macros::dyn_map_fun!(self, DynCsrNonCanonical, get, index) } fn select(&self, info: &[S]) -> Self where S: AsRef, { - macro_rules! select { - ($data:expr) => { + macro_rules! fun { + ($variant:ident, $data:expr) => { $data.select(info).into() }; } - impl_dyn_csr_matrix!(self, select) + crate::macros::dyn_map!(self, DynCsrNonCanonical, fun) } fn vstack>(iter: I) -> Result { @@ -324,44 +242,12 @@ impl ReadArrayData for DynCsrNonCanonical { S: AsRef, { if let DataType::CsrMatrix(ty) = container.encoding_type()? { - match ty { - ScalarType::I8 => { - CsrNonCanonical::::read_select(container, info).map(Into::into) - } - ScalarType::I16 => { - CsrNonCanonical::::read_select(container, info).map(Into::into) - } - ScalarType::I32 => { - CsrNonCanonical::::read_select(container, info).map(Into::into) - } - ScalarType::I64 => { - CsrNonCanonical::::read_select(container, info).map(Into::into) - } - ScalarType::U8 => { - CsrNonCanonical::::read_select(container, info).map(Into::into) - } - ScalarType::U16 => { - CsrNonCanonical::::read_select(container, info).map(Into::into) - } - ScalarType::U32 => { - CsrNonCanonical::::read_select(container, info).map(Into::into) - } - ScalarType::U64 => { - CsrNonCanonical::::read_select(container, info).map(Into::into) - } - ScalarType::F32 => { - CsrNonCanonical::::read_select(container, info).map(Into::into) - } - ScalarType::F64 => { - CsrNonCanonical::::read_select(container, info).map(Into::into) - } - ScalarType::Bool => { - CsrNonCanonical::::read_select(container, info).map(Into::into) - } - ScalarType::String => { - CsrNonCanonical::::read_select(container, info).map(Into::into) - } + macro_rules! fun { + ($variant:ident) => { + CsrNonCanonical::read_select(container, info).map(DynCsrNonCanonical::$variant) + }; } + crate::macros::dyn_match!(ty, ScalarType, fun) } else { bail!("the container does not contain a csr matrix"); } @@ -755,7 +641,14 @@ impl WriteData for CsrNonCanonical { group.new_str_attr("encoding-type", "csr_matrix")?; group.new_str_attr("encoding-version", "0.1.0")?; - group.new_array_attr("shape", &shape.as_ref().iter().map(|x| *x as u64).collect::>())?; + group.new_array_attr( + "shape", + &shape + .as_ref() + .iter() + .map(|x| *x as u64) + .collect::>(), + )?; group.new_array_dataset("data", self.values().into(), Default::default())?; @@ -848,7 +741,11 @@ impl ReadData for CsrNonCanonical { .into_raw_vec_and_offset() .0; Ok(Self::from_csr_data( - shape[0] as usize, shape[1] as usize, indptr, indices, data, + shape[0] as usize, + shape[1] as usize, + indptr, + indices, + data, )) } } diff --git a/anndata/src/traits.rs b/anndata/src/traits.rs index 589ea4a..1912e21 100644 --- a/anndata/src/traits.rs +++ b/anndata/src/traits.rs @@ -669,11 +669,7 @@ impl AxisArraysOp for &StackedAxisArrays { self.data.get(key).cloned() } - fn add>( - &self, - _key: &str, - _data: D, - ) -> Result<()> { + fn add>(&self, _key: &str, _data: D) -> Result<()> { todo!() }