From 65312c8f226635d836c42b6fa0b7f05a3edc10e5 Mon Sep 17 00:00:00 2001 From: Kai Zhang Date: Mon, 9 Dec 2024 17:26:35 +0800 Subject: [PATCH] minor fix --- snapatac2-python/src/embedding.rs | 42 ++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/snapatac2-python/src/embedding.rs b/snapatac2-python/src/embedding.rs index fe147d3d..5ac2cc40 100644 --- a/snapatac2-python/src/embedding.rs +++ b/snapatac2-python/src/embedding.rs @@ -1,7 +1,9 @@ use crate::utils::AnnDataLike; use anndata::{ - data::{array::utils::to_csr_data, SelectInfoElem, SelectInfoElemBounds}, + data::{ + array::utils::to_csr_data, ArrayConvert, DynCsrMatrix, SelectInfoElem, SelectInfoElemBounds, + }, AnnDataOp, ArrayData, ArrayElemOp, Backend, Selectable, }; use anndata_hdf5::H5; @@ -32,7 +34,11 @@ pub(crate) fn spectral_embedding<'py>( macro_rules! run { ($data:expr) => {{ let slice = pyanndata::data::to_select_elem(selected_features, $data.n_vars())?; - let mut mat: CsrMatrix = $data.x().slice_axis(1, slice)?.unwrap(); + let mut mat: CsrMatrix = $data + .x() + .slice_axis::(1, slice)? + .unwrap() + .try_convert()?; if let Some(weights) = feature_weights { normalize(&mut mat, &weights); } else { @@ -71,12 +77,10 @@ pub(crate) fn spectral_embedding_nystrom<'py>( let weights = if let Some(weights) = feature_weights { weights } else { - idf_from_chunks( - $data - .x() - .iter(5000) - .map(|x: (CsrMatrix, _, _)| x.0.select_axis(1, &selected_features)), - ) + idf_from_chunks($data.x().iter(5000).map(|x: (DynCsrMatrix, _, _)| { + let mat: CsrMatrix = x.0.try_convert().unwrap(); + mat.select_axis(1, &selected_features) + })) }; let n_obs = $data.n_obs(); @@ -91,8 +95,9 @@ pub(crate) fn spectral_embedding_nystrom<'py>( let selected_samples = SelectInfoElem::from(idx); let mut seed_mat: CsrMatrix = $data .x() - .slice(&[selected_samples, selected_features.clone()])? - .unwrap(); + .slice::(&[selected_samples, selected_features.clone()])? + .unwrap() + .try_convert()?; // feature weighting and L2 norm normalization. normalize(&mut seed_mat, &weights); @@ -106,8 +111,9 @@ pub(crate) fn spectral_embedding_nystrom<'py>( &v, &mut u, &d, - $data.x().iter(chunk_size).map(|x: (CsrMatrix, _, _)| { - let mut mat = x.0.select_axis(1, &selected_features); + $data.x().iter(chunk_size).map(|x: (DynCsrMatrix, _, _)| { + let mut mat: CsrMatrix = x.0.try_convert().unwrap(); + mat = mat.select_axis(1, &selected_features); normalize(&mut mat, &weights); mat }), @@ -175,7 +181,8 @@ fn spectral_mf( random_state, ); let result = fun.call1(py, args)?; - let (evals, evecs): (PyReadonlyArray1<'_, f64>, PyReadonlyArray2<'_, f64>) = result.extract(py)?; + let (evals, evecs): (PyReadonlyArray1<'_, f64>, PyReadonlyArray2<'_, f64>) = + result.extract(py)?; anyhow::Ok((evals.to_owned_array(), evecs.to_owned_array())) })?; @@ -396,8 +403,13 @@ pub(crate) fn multi_spectral_embedding<'py>( ($data:expr) => {{ let slice = pyanndata::data::to_select_elem(&s, $data.n_vars()) .expect("Invalid feature selection"); - let mut mat: CsrMatrix = - $data.x().slice_axis(1, slice).unwrap().expect("X is None"); + let mut mat: CsrMatrix = $data + .x() + .slice_axis::(1, slice) + .unwrap() + .expect("X is None") + .try_convert() + .unwrap(); let feature_weights = idf(&mat); // feature weighting and L2 norm normalization.