Skip to content

Commit

Permalink
zal: Add a caching API
Browse files Browse the repository at this point in the history
  • Loading branch information
mratsim committed Jan 12, 2024
1 parent 3f40180 commit d3f1633
Showing 1 changed file with 69 additions and 0 deletions.
69 changes: 69 additions & 0 deletions src/zal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,47 @@ pub trait ZalEngine: Debug {}

pub trait MsmAccel<C: CurveAffine>: ZalEngine {
fn msm(&self, coeffs: &[C::Scalar], base: &[C]) -> C::Curve;

// Caching API
// -------------------------------------------------
// From here we propose an extended API
// that allows reusing coeffs and/or the base points
//
// This is inspired by CuDNN API (Nvidia GPU)
// and oneDNN API (CPU, OpenCL) https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnn-ops-infer-so-opaque
// usage of descriptors
//
// https://github.com/oneapi-src/oneDNN/blob/master/doc/programming_model/basic_concepts.md
//
// Descriptors are opaque pointers that hold the input in a format suitable for the accelerator engine.
// They may be:
// - Input moved on accelerator device (only once for repeated calls)
// - Endianess conversion
// - Converting from Montgomery to Canonical form
// - Input changed from Projective to Jacobian coordinates or even to a Twisted Edwards curve.
// - other form of expensive preprocessing
type CoeffsDescriptor<'c>;
type BaseDescriptor<'b>;

fn get_coeffs_descriptor<'c>(&self, coeffs: &'c [C::Scalar]) -> Self::CoeffsDescriptor<'c>;
fn get_base_descriptor<'b>(&self, base: &'b [C]) -> Self::BaseDescriptor<'b>;

fn msm_with_cached_scalars(&self, coeffs: &Self::CoeffsDescriptor<'_>, base: &[C]) -> C::Curve;

fn msm_with_cached_base(&self, coeffs: &[C::Scalar], base: &Self::BaseDescriptor<'_>) -> C::Curve;

fn msm_with_cached_inputs(&self, coeffs: &Self::CoeffsDescriptor<'_>, base: &Self::BaseDescriptor<'_>) -> C::Curve;
// Execute MSM according to descriptors
// Unsure of naming, msm_with_cached_inputs, msm_apply, msm_cached, msm_with_descriptors, ...
}

// ZAL using Halo2curves as a backend
// ---------------------------------------------------

#[derive(Debug)]
pub struct H2cEngine;
pub struct H2cMsmCoeffsDesc<'c, C: CurveAffine> { raw: &'c [C::Scalar]}
pub struct H2cMsmBaseDesc<'b, C: CurveAffine> { raw: &'b [C]}

impl H2cEngine {
pub fn new() -> Self {
Expand All @@ -69,6 +103,32 @@ impl<C: CurveAffine> MsmAccel<C> for H2cEngine {
#[allow(deprecated)]
best_multiexp(coeffs, bases)
}

// Caching API
// -------------------------------------------------

type CoeffsDescriptor<'c> = H2cMsmCoeffsDesc<'c, C>;
type BaseDescriptor<'b> = H2cMsmBaseDesc<'b, C>;

fn get_coeffs_descriptor<'c>(&self, coeffs: &'c [C::Scalar]) -> Self::CoeffsDescriptor<'c>{
// Do expensive device/library specific preprocessing here
Self::CoeffsDescriptor { raw: coeffs }
}
fn get_base_descriptor<'b>(&self, base: &'b [C]) -> Self::BaseDescriptor<'b> {
Self::BaseDescriptor { raw: base }
}

fn msm_with_cached_scalars(&self, coeffs: &Self::CoeffsDescriptor<'_>, base: &[C]) -> C::Curve {
best_multiexp(coeffs.raw, base)
}

fn msm_with_cached_base(&self, coeffs: &[C::Scalar], base: &Self::BaseDescriptor<'_>) -> C::Curve {
best_multiexp(coeffs, base.raw)
}

fn msm_with_cached_inputs(&self, coeffs: &Self::CoeffsDescriptor<'_>, base: &Self::BaseDescriptor<'_>) -> C::Curve {
best_multiexp(coeffs.raw, base.raw)
}
}

impl Default for H2cEngine {
Expand Down Expand Up @@ -117,6 +177,15 @@ mod test {
end_timer!(t1);

assert_eq!(e0, e1);

// Caching API
// -----------
let t2 = start_timer!(|| format!("H2cEngine msm cached base k={}", k));
let base_descriptor = engine.get_base_descriptor(points);
let e2 = engine.msm_with_cached_base(scalars, &base_descriptor);
end_timer!(t2);

assert_eq!(e0, e2)
}
}

Expand Down

0 comments on commit d3f1633

Please sign in to comment.