Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement missing traits to allow Scalar to be used as generic type for polynomial::Polynomial #411

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ serde = { version = "1.0", default-features = false, optional = true, features =
packed_simd = { version = "0.3.4", package = "packed_simd_2", features = ["into_bits"], optional = true }
zeroize = { version = ">=1, <1.4", default-features = false }
fiat-crypto = { version = "0.1.6", optional = true}
num-traits = "0.2"
hex = "0.4.3"

[features]
nightly = ["subtle/nightly"]
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,9 @@ extern crate packed_simd;
extern crate byteorder;
pub extern crate digest;
extern crate rand_core;
extern crate num_traits;
extern crate zeroize;
extern crate hex;

#[cfg(any(feature = "fiat_u64_backend", feature = "fiat_u32_backend"))]
extern crate fiat_crypto;
Expand Down
21 changes: 20 additions & 1 deletion src/ristretto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,14 @@
//! https://ristretto.group/

use core::borrow::Borrow;
use core::fmt::Debug;
use core::fmt::{Debug, Display};
use core::iter::Sum;
use core::ops::{Add, Neg, Sub};
use core::ops::{AddAssign, SubAssign};
use core::ops::{Mul, MulAssign};

use::num_traits::Zero;

use rand_core::{CryptoRng, RngCore};

use digest::generic_array::typenum::U64;
Expand Down Expand Up @@ -422,6 +424,23 @@ impl<'de> Deserialize<'de> for CompressedRistretto {
}
}

impl Zero for RistrettoPoint {
fn zero() -> Self {
RistrettoPoint::identity()
}
fn is_zero(&self) -> bool {
self == &RistrettoPoint::identity()
}
}

impl Display for RistrettoPoint {
fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
let data = hex::encode(self.compress().as_bytes());
write!(f, "{}", data)?;
Ok(())
}
}

// ------------------------------------------------------------------------
// Internal point representations
// ------------------------------------------------------------------------
Expand Down
51 changes: 47 additions & 4 deletions src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,16 +138,18 @@
//!
//! The resulting `Scalar` has exactly the specified bit pattern,
//! **except for the highest bit, which will be set to 0**.

use core::borrow::Borrow;
use core::cmp::{Eq, PartialEq};
use core::fmt::Debug;
use core::cmp::{Eq, Ord, PartialEq};
use core::fmt::{Debug, Display};
use core::iter::{Product, Sum};
use core::ops::Index;
use core::ops::Neg;
use core::ops::{Add, AddAssign};
use core::ops::{Mul, MulAssign};
use core::ops::{Sub, SubAssign};
use core::ops::{Div, DivAssign};

use::num_traits::{Zero, One};

#[allow(unused_imports)]
use prelude::*;
Expand Down Expand Up @@ -192,7 +194,7 @@ type UnpackedScalar = backend::serial::u32::scalar::Scalar29;

/// The `Scalar` struct holds an integer \\(s < 2\^{255} \\) which
/// represents an element of \\(\mathbb Z / \ell\\).
#[derive(Copy, Clone, Hash)]
#[derive(Copy, Clone, Hash, Ord, PartialOrd)]
pub struct Scalar {
/// `bytes` is a little-endian byte encoding of an integer representing a scalar modulo the
/// group order.
Expand Down Expand Up @@ -268,6 +270,14 @@ impl Debug for Scalar {
}
}

impl Display for Scalar {
fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
let data = hex::encode(self.bytes);
write!(f, "{}", data)?;
Ok(())
}
}

impl Eq for Scalar {}
impl PartialEq for Scalar {
fn eq(&self, other: &Self) -> bool {
Expand All @@ -290,6 +300,39 @@ impl Index<usize> for Scalar {
}
}

impl Zero for Scalar {
fn zero() -> Self {
Scalar::zero()
}
fn is_zero(&self) -> bool {
self == &Scalar::zero()
}
}

impl One for Scalar {
fn one() -> Self {
Scalar::one()
}
fn is_one(&self) -> bool {
self == &Scalar::one()
}
}

impl Div<Scalar> for Scalar {
type Output = Scalar;
fn div(self, q: Scalar) -> Self::Output {
let q1 = q.invert();
self * q1
}
}

impl DivAssign<Scalar> for Scalar {
fn div_assign(&mut self, q: Scalar) {
let q1 = q.invert();
*self = *self * q1;
}
}

impl<'b> MulAssign<&'b Scalar> for Scalar {
fn mul_assign(&mut self, _rhs: &'b Scalar) {
*self = UnpackedScalar::mul(&self.unpack(), &_rhs.unpack()).pack();
Expand Down