Skip to content

Commit

Permalink
minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
kaizhang committed Dec 9, 2024
1 parent 2c2e16b commit 65312c8
Showing 1 changed file with 27 additions and 15 deletions.
42 changes: 27 additions & 15 deletions snapatac2-python/src/embedding.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use crate::utils::AnnDataLike;

use anndata::{
data::{array::utils::to_csr_data, SelectInfoElem, SelectInfoElemBounds},
data::{
array::utils::to_csr_data, ArrayConvert, DynCsrMatrix, SelectInfoElem, SelectInfoElemBounds,
},
AnnDataOp, ArrayData, ArrayElemOp, Backend, Selectable,
};
use anndata_hdf5::H5;
Expand Down Expand Up @@ -32,7 +34,11 @@ pub(crate) fn spectral_embedding<'py>(
macro_rules! run {
($data:expr) => {{
let slice = pyanndata::data::to_select_elem(selected_features, $data.n_vars())?;
let mut mat: CsrMatrix<f64> = $data.x().slice_axis(1, slice)?.unwrap();
let mut mat: CsrMatrix<f64> = $data
.x()
.slice_axis::<DynCsrMatrix, _>(1, slice)?
.unwrap()
.try_convert()?;
if let Some(weights) = feature_weights {
normalize(&mut mat, &weights);
} else {
Expand Down Expand Up @@ -71,12 +77,10 @@ pub(crate) fn spectral_embedding_nystrom<'py>(
let weights = if let Some(weights) = feature_weights {
weights
} else {
idf_from_chunks(
$data
.x()
.iter(5000)
.map(|x: (CsrMatrix<f64>, _, _)| x.0.select_axis(1, &selected_features)),
)
idf_from_chunks($data.x().iter(5000).map(|x: (DynCsrMatrix, _, _)| {
let mat: CsrMatrix<f64> = x.0.try_convert().unwrap();
mat.select_axis(1, &selected_features)
}))
};

let n_obs = $data.n_obs();
Expand All @@ -91,8 +95,9 @@ pub(crate) fn spectral_embedding_nystrom<'py>(
let selected_samples = SelectInfoElem::from(idx);
let mut seed_mat: CsrMatrix<f64> = $data
.x()
.slice(&[selected_samples, selected_features.clone()])?
.unwrap();
.slice::<DynCsrMatrix, _>(&[selected_samples, selected_features.clone()])?
.unwrap()
.try_convert()?;

// feature weighting and L2 norm normalization.
normalize(&mut seed_mat, &weights);
Expand All @@ -106,8 +111,9 @@ pub(crate) fn spectral_embedding_nystrom<'py>(
&v,
&mut u,
&d,
$data.x().iter(chunk_size).map(|x: (CsrMatrix<f64>, _, _)| {
let mut mat = x.0.select_axis(1, &selected_features);
$data.x().iter(chunk_size).map(|x: (DynCsrMatrix, _, _)| {
let mut mat: CsrMatrix<f64> = x.0.try_convert().unwrap();
mat = mat.select_axis(1, &selected_features);
normalize(&mut mat, &weights);
mat
}),
Expand Down Expand Up @@ -175,7 +181,8 @@ fn spectral_mf(
random_state,
);
let result = fun.call1(py, args)?;
let (evals, evecs): (PyReadonlyArray1<'_, f64>, PyReadonlyArray2<'_, f64>) = result.extract(py)?;
let (evals, evecs): (PyReadonlyArray1<'_, f64>, PyReadonlyArray2<'_, f64>) =
result.extract(py)?;

anyhow::Ok((evals.to_owned_array(), evecs.to_owned_array()))
})?;
Expand Down Expand Up @@ -396,8 +403,13 @@ pub(crate) fn multi_spectral_embedding<'py>(
($data:expr) => {{
let slice = pyanndata::data::to_select_elem(&s, $data.n_vars())
.expect("Invalid feature selection");
let mut mat: CsrMatrix<f64> =
$data.x().slice_axis(1, slice).unwrap().expect("X is None");
let mut mat: CsrMatrix<f64> = $data
.x()
.slice_axis::<DynCsrMatrix, _>(1, slice)
.unwrap()
.expect("X is None")
.try_convert()
.unwrap();
let feature_weights = idf(&mat);

// feature weighting and L2 norm normalization.
Expand Down

0 comments on commit 65312c8

Please sign in to comment.