Skip to content

Commit

Permalink
sm2: add SM2PKE support (#1069)
Browse files Browse the repository at this point in the history
Adds support for the SM2 public key encryption algorithm defined in 
China's national standard GBT.32918.4-2016 (a.k.a. SM2-4)

Closes #1067

Co-authored-by: Tony Arcieri <[email protected]>
  • Loading branch information
heliannuuthus and tarcieri authored Sep 5, 2024
1 parent d382b1a commit 4781762
Show file tree
Hide file tree
Showing 8 changed files with 689 additions and 4 deletions.
5 changes: 3 additions & 2 deletions sm2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ homepage = "https://github.com/RustCrypto/elliptic-curves/tree/master/sm2"
repository = "https://github.com/RustCrypto/elliptic-curves"
readme = "README.md"
categories = ["cryptography", "no-std"]
keywords = ["crypto", "ecc", "shangmi", "signature"]
keywords = ["crypto", "ecc", "shangmi", "signature", "encryption"]
edition = "2021"
rust-version = "1.73"

Expand All @@ -33,13 +33,14 @@ proptest = "1"
rand_core = { version = "0.6", features = ["getrandom"] }

[features]
default = ["arithmetic", "dsa", "pem", "std"]
default = ["arithmetic", "dsa", "pke", "pem", "std"]
alloc = ["elliptic-curve/alloc"]
std = ["alloc", "elliptic-curve/std", "signature?/std"]

arithmetic = ["dep:primeorder", "elliptic-curve/arithmetic"]
bits = ["arithmetic", "elliptic-curve/bits"]
dsa = ["arithmetic", "dep:rfc6979", "dep:signature", "dep:sm3"]
pke = ["arithmetic", "dep:sm3"]
getrandom = ["rand_core/getrandom"]
pem = ["elliptic-curve/pem", "pkcs8"]
pkcs8 = ["elliptic-curve/pkcs8"]
Expand Down
2 changes: 1 addition & 1 deletion sm2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ The SM2 cryptosystem is composed of three distinct algorithms:

- [x] **SM2DSA**: digital signature algorithm defined in [GBT.32918.2-2016], [ISO.IEC.14888-3] (SM2-2)
- [ ] **SM2KEP**: key exchange protocol defined in [GBT.32918.3-2016] (SM2-3)
- [ ] **SM2PKE**: public key encryption algorithm defined in [GBT.32918.4-2016] (SM2-4)
- [x] **SM2PKE**: public key encryption algorithm defined in [GBT.32918.4-2016] (SM2-4)

## Minimum Supported Rust Version

Expand Down
2 changes: 1 addition & 1 deletion sm2/src/arithmetic/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ use core::{
iter::{Product, Sum},
ops::{AddAssign, MulAssign, Neg, SubAssign},
};
use elliptic_curve::ops::Invert;
use elliptic_curve::{
bigint::Limb,
ff::PrimeField,
ops::Invert,
subtle::{Choice, ConstantTimeEq, CtOption},
};

Expand Down
3 changes: 3 additions & 0 deletions sm2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ extern crate alloc;
#[cfg(feature = "dsa")]
pub mod dsa;

#[cfg(feature = "pke")]
pub mod pke;

#[cfg(feature = "arithmetic")]
mod arithmetic;
#[cfg(feature = "dsa")]
Expand Down
178 changes: 178 additions & 0 deletions sm2/src/pke.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
//! SM2 Encryption Algorithm (SM2) as defined in [draft-shen-sm2-ecdsa § 5].
//!
//! ## Usage
//!
//! NOTE: requires the `sm3` crate for digest functions and the `primeorder` crate for prime field operations.
//!
//! The `DecryptingKey` struct is used for decrypting messages that were encrypted using the SM2 encryption algorithm.
//! It is initialized with a `SecretKey` or a non-zero scalar value and can decrypt ciphertexts using the specified decryption mode.
#![cfg_attr(feature = "std", doc = "```")]
#![cfg_attr(not(feature = "std"), doc = "```ignore")]
//! # fn example() -> Result<(), Box<dyn std::error::Error>> {
//! use rand_core::OsRng; // requires 'getrandom` feature
//! use sm2::{
//! pke::{EncryptingKey, Mode},
//! {SecretKey, PublicKey}
//!
//! };
//!
//! // Encrypting
//! let secret_key = SecretKey::random(&mut OsRng); // serialize with `::to_bytes()`
//! let public_key = secret_key.public_key();
//! let encrypting_key = EncryptingKey::new_with_mode(public_key, Mode::C1C2C3);
//! let plaintext = b"plaintext";
//! let ciphertext = encrypting_key.encrypt(plaintext)?;
//!
//! use sm2::pke::DecryptingKey;
//! // Decrypting
//! let decrypting_key = DecryptingKey::new_with_mode(secret_key.to_nonzero_scalar(), Mode::C1C2C3);
//! assert_eq!(decrypting_key.decrypt(&ciphertext)?, plaintext);
//!
//! // Encrypting ASN.1 DER
//! let ciphertext = encrypting_key.encrypt_der(plaintext)?;
//!
//! // Decrypting ASN.1 DER
//! assert_eq!(decrypting_key.decrypt_der(&ciphertext)?, plaintext);
//!
//! Ok(())
//! # }
//! ```
//!
//!
//!
use core::cmp::min;

use crate::AffinePoint;

#[cfg(feature = "alloc")]
use alloc::vec;

use elliptic_curve::{
bigint::{Encoding, Uint, U256},
pkcs8::der::{
asn1::UintRef, Decode, DecodeValue, Encode, Length, Reader, Sequence, Tag, Writer,
},
};

use elliptic_curve::{
pkcs8::der::{asn1::OctetStringRef, EncodeValue},
sec1::ToEncodedPoint,
Result,
};
use sm3::digest::DynDigest;

#[cfg(feature = "arithmetic")]
mod decrypting;
#[cfg(feature = "arithmetic")]
mod encrypting;

#[cfg(feature = "arithmetic")]
pub use self::{decrypting::DecryptingKey, encrypting::EncryptingKey};

/// Modes for the cipher encoding/decoding.
#[derive(Clone, Copy, Debug)]
pub enum Mode {
/// old mode
C1C2C3,
/// new mode
C1C3C2,
}
/// Represents a cipher structure containing encryption-related data (asn.1 format).
///
/// The `Cipher` structure includes the coordinates of the elliptic curve point (`x`, `y`),
/// the digest of the message, and the encrypted cipher text.
pub struct Cipher<'a> {
x: U256,
y: U256,
digest: &'a [u8],
cipher: &'a [u8],
}

impl<'a> Sequence<'a> for Cipher<'a> {}

impl<'a> EncodeValue for Cipher<'a> {
fn value_len(&self) -> elliptic_curve::pkcs8::der::Result<Length> {
UintRef::new(&self.x.to_be_bytes())?.encoded_len()?
+ UintRef::new(&self.y.to_be_bytes())?.encoded_len()?
+ OctetStringRef::new(self.digest)?.encoded_len()?
+ OctetStringRef::new(self.cipher)?.encoded_len()?
}

fn encode_value(&self, writer: &mut impl Writer) -> elliptic_curve::pkcs8::der::Result<()> {
UintRef::new(&self.x.to_be_bytes())?.encode(writer)?;
UintRef::new(&self.y.to_be_bytes())?.encode(writer)?;
OctetStringRef::new(self.digest)?.encode(writer)?;
OctetStringRef::new(self.cipher)?.encode(writer)?;
Ok(())
}
}

impl<'a> DecodeValue<'a> for Cipher<'a> {
type Error = elliptic_curve::pkcs8::der::Error;

fn decode_value<R: Reader<'a>>(
decoder: &mut R,
header: elliptic_curve::pkcs8::der::Header,
) -> core::result::Result<Self, Self::Error> {
decoder.read_nested(header.length, |nr| {
let x = UintRef::decode(nr)?.as_bytes();
let y = UintRef::decode(nr)?.as_bytes();
let digest = OctetStringRef::decode(nr)?.into();
let cipher = OctetStringRef::decode(nr)?.into();
Ok(Cipher {
x: Uint::from_be_bytes(zero_pad_byte_slice(x)?),
y: Uint::from_be_bytes(zero_pad_byte_slice(y)?),
digest,
cipher,
})
})
}
}

/// Performs key derivation using a hash function and elliptic curve point.
fn kdf(hasher: &mut dyn DynDigest, kpb: AffinePoint, c2: &mut [u8]) -> Result<()> {
let klen = c2.len();
let mut ct: i32 = 0x00000001;
let mut offset = 0;
let digest_size = hasher.output_size();
let mut ha = vec![0u8; digest_size];
let encode_point = kpb.to_encoded_point(false);

while offset < klen {
hasher.update(encode_point.x().ok_or(elliptic_curve::Error)?);
hasher.update(encode_point.y().ok_or(elliptic_curve::Error)?);
hasher.update(&ct.to_be_bytes());

hasher
.finalize_into_reset(&mut ha)
.map_err(|_e| elliptic_curve::Error)?;

let xor_len = min(digest_size, klen - offset);
xor(c2, &ha, offset, xor_len);
offset += xor_len;
ct += 1;
}
Ok(())
}

/// XORs a portion of the buffer `c2` with a hash value.
fn xor(c2: &mut [u8], ha: &[u8], offset: usize, xor_len: usize) {
for i in 0..xor_len {
c2[offset + i] ^= ha[i];
}
}

/// Converts a byte slice to a fixed-size array, padding with leading zeroes if necessary.
pub(crate) fn zero_pad_byte_slice<const N: usize>(
bytes: &[u8],
) -> elliptic_curve::pkcs8::der::Result<[u8; N]> {
let num_zeroes = N
.checked_sub(bytes.len())
.ok_or_else(|| Tag::Integer.length_error())?;

// Copy input into `N`-sized output buffer with leading zeroes
let mut output = [0u8; N];
output[num_zeroes..].copy_from_slice(bytes);
Ok(output)
}
Loading

0 comments on commit 4781762

Please sign in to comment.