Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add serde for some matrix types #5

Merged
merged 2 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion air/src/two_row_matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ impl<'a, T> Matrix<T> for TwoRowMatrixView<'a, T> {
}

impl<T: Clone + core::fmt::Debug> MatrixRows<T> for TwoRowMatrixView<'_, T> {
type Row<'a> = Cloned<slice::Iter<'a, T>> where Self: 'a, T: 'a;
type Row<'a>
= Cloned<slice::Iter<'a, T>>
where
Self: 'a,
T: 'a;

fn row(&self, r: usize) -> Self::Row<'_> {
let slice = match r {
Expand Down
10 changes: 8 additions & 2 deletions commit/src/adapters/extension_mmcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ where
type Commitment = InnerMmcs::Commitment;
type Proof = InnerMmcs::Proof;
type Error = InnerMmcs::Error;
type Mat<'a> = ExtensionMatrix<F, EF, InnerMmcs::Mat<'a>> where Self: 'a;
type Mat<'a>
= ExtensionMatrix<F, EF, InnerMmcs::Mat<'a>>
where
Self: 'a;

fn open_batch(
&self,
Expand Down Expand Up @@ -131,7 +134,10 @@ where
EF: ExtensionField<F>,
InnerMat: MatrixRows<F>,
{
type Row<'a> = ExtensionRow<F, EF, <<InnerMat as MatrixRows<F>>::Row<'a> as IntoIterator>::IntoIter> where Self: 'a;
type Row<'a>
= ExtensionRow<F, EF, <<InnerMat as MatrixRows<F>>::Row<'a> as IntoIterator>::IntoIter>
where
Self: 'a;

fn row(&self, r: usize) -> Self::Row<'_> {
ExtensionRow {
Expand Down
5 changes: 4 additions & 1 deletion fri/src/two_adic_pcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,10 @@ where
<C::FriMmcs as Mmcs<C::Challenge>>::ProverData: Send + Sync,
<C::InputMmcs as Mmcs<C::Val>>::ProverData: Send + Sync + Sized,
{
type Lde<'a> = BitReversedMatrixView<<C::InputMmcs as Mmcs<C::Val>>::Mat<'a>> where Self: 'a;
type Lde<'a>
= BitReversedMatrixView<<C::InputMmcs as Mmcs<C::Val>>::Mat<'a>>
where
Self: 'a;

fn coset_shift(&self) -> C::Val {
C::Val::generator()
Expand Down
4 changes: 4 additions & 0 deletions matrix/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@ version = "0.1.0"
edition = "2021"
license = "MIT OR Apache-2.0"

[features]
serde = ["dep:serde"]

[dependencies]
p3-field = { path = "../field" }
p3-maybe-rayon = { path = "../maybe-rayon" }
p3-util = { path = "../util" }
rand = "0.8.5"
serde = { version = "1.0.210", optional = true }

[dev-dependencies]
criterion = "0.5.1"
Expand Down
5 changes: 4 additions & 1 deletion matrix/src/bitrev.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ impl<T, Inner: MatrixGet<T>> MatrixGet<T> for BitReversedMatrixView<Inner> {
}

impl<T: core::fmt::Debug, Inner: MatrixRows<T>> MatrixRows<T> for BitReversedMatrixView<Inner> {
type Row<'a> = Inner::Row<'a> where Self: 'a;
type Row<'a>
= Inner::Row<'a>
where
Self: 'a;

fn row(&self, r: usize) -> Self::Row<'_> {
self.inner.row(reverse_bits_len(r, self.log_height))
Expand Down
20 changes: 17 additions & 3 deletions matrix/src/dense.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use p3_field::{ExtensionField, Field, PackedField};
use p3_maybe_rayon::prelude::*;
use rand::distributions::{Distribution, Standard};
use rand::Rng;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use crate::{Matrix, MatrixGet, MatrixRowSlices, MatrixRowSlicesMut, MatrixRows, MatrixTranspose};

Expand All @@ -16,6 +18,7 @@ const TRANSPOSE_BLOCK_SIZE: usize = 64;

/// A dense matrix stored in row-major form.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct RowMajorMatrix<T: Debug> {
/// All values, stored in row-major order.
pub values: Vec<T>,
Expand Down Expand Up @@ -201,7 +204,10 @@ impl<T: Clone + core::fmt::Debug> MatrixGet<T> for RowMajorMatrix<T> {
}

impl<T: Clone + core::fmt::Debug> MatrixRows<T> for RowMajorMatrix<T> {
type Row<'a> = Cloned<slice::Iter<'a, T>> where T: 'a;
type Row<'a>
= Cloned<slice::Iter<'a, T>>
where
T: 'a;

fn row(&self, r: usize) -> Self::Row<'_> {
self.row_slice(r).iter().cloned()
Expand Down Expand Up @@ -284,7 +290,11 @@ impl<T: Clone> MatrixGet<T> for RowMajorMatrixView<'_, T> {
}

impl<T: Clone + core::fmt::Debug> MatrixRows<T> for RowMajorMatrixView<'_, T> {
type Row<'a> = Cloned<slice::Iter<'a, T>> where Self: 'a, T: 'a;
type Row<'a>
= Cloned<slice::Iter<'a, T>>
where
Self: 'a,
T: 'a;

fn row(&self, r: usize) -> Self::Row<'_> {
self.row_slice(r).iter().cloned()
Expand Down Expand Up @@ -420,7 +430,11 @@ impl<T: Clone> MatrixGet<T> for RowMajorMatrixViewMut<'_, T> {
}

impl<T: Clone + core::fmt::Debug> MatrixRows<T> for RowMajorMatrixViewMut<'_, T> {
type Row<'a> = Cloned<slice::Iter<'a, T>> where Self: 'a, T: 'a;
type Row<'a>
= Cloned<slice::Iter<'a, T>>
where
Self: 'a,
T: 'a;

fn row(&self, r: usize) -> Self::Row<'_> {
self.row_slice(r).iter().cloned()
Expand Down
6 changes: 5 additions & 1 deletion matrix/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ extern crate alloc;
use alloc::vec::Vec;
use core::fmt::{Debug, Display, Formatter};

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use crate::dense::RowMajorMatrix;
use crate::strided::VerticallyStridedMatrixView;

Expand All @@ -31,7 +34,8 @@ pub trait Matrix<T> {
}
}

#[derive(Clone, Copy)]
#[derive(Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct Dimensions {
pub width: usize,
pub height: usize,
Expand Down
4 changes: 4 additions & 0 deletions matrix/src/sparse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@ use core::ops::Range;

use rand::distributions::{Distribution, Standard};
use rand::Rng;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use crate::Matrix;

/// A sparse matrix stored in the compressed sparse row format.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct CsrMatrix<T> {
width: usize,

Expand Down
5 changes: 4 additions & 1 deletion matrix/src/stack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ where
Second: MatrixRows<T>,
T: core::fmt::Debug,
{
type Row<'a> = EitherIterable<First::Row<'a>, Second::Row<'a>> where Self: 'a;
type Row<'a>
= EitherIterable<First::Row<'a>, Second::Row<'a>>
where
Self: 'a;

fn row(&self, r: usize) -> Self::Row<'_> {
if r < self.first.height() {
Expand Down
5 changes: 4 additions & 1 deletion matrix/src/strided.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ impl<T, Inner: MatrixGet<T>> MatrixGet<T> for VerticallyStridedMatrixView<Inner>
impl<T: core::fmt::Debug, Inner: MatrixRows<T>> MatrixRows<T>
for VerticallyStridedMatrixView<Inner>
{
type Row<'a> = Inner::Row<'a> where Self: 'a;
type Row<'a>
= Inner::Row<'a>
where
Self: 'a;

fn row(&self, r: usize) -> Self::Row<'_> {
self.inner.row(r * self.stride + self.offset)
Expand Down
6 changes: 5 additions & 1 deletion merkle-tree/src/mmcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ where
type Commitment = [P::Scalar; DIGEST_ELEMS];
type Proof = Vec<[P::Scalar; DIGEST_ELEMS]>;
type Error = ();
type Mat<'a> = RowMajorMatrixView<'a, P::Scalar> where H: 'a, C: 'a;
type Mat<'a>
= RowMajorMatrixView<'a, P::Scalar>
where
H: 'a,
C: 'a;

fn open_batch(
&self,
Expand Down
6 changes: 5 additions & 1 deletion tensor-pcs/src/wrapped_matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ where
M: MatrixRows<T>,
T: core::fmt::Debug,
{
type Row<'a> = WrappedMatrixRow<'a, T, M> where T: 'a, M: 'a;
type Row<'a>
= WrappedMatrixRow<'a, T, M>
where
T: 'a,
M: 'a;

fn row(&self, r: usize) -> Self::Row<'_> {
WrappedMatrixRow {
Expand Down
6 changes: 5 additions & 1 deletion uni-stark/src/public.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,11 @@ impl<T> Matrix<T> for PublicRow<T> {
}

impl<T: Clone + core::fmt::Debug> MatrixRows<T> for PublicRow<T> {
type Row<'a> = Cloned<slice::Iter<'a, T>> where T: 'a, Self: 'a;
type Row<'a>
= Cloned<slice::Iter<'a, T>>
where
T: 'a,
Self: 'a;

fn row(&self, _r: usize) -> Self::Row<'_> {
self.0.iter().cloned()
Expand Down
Loading