Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ECDH - MLKEM hybrid key exchange #1624

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ private enum All
OQS_mlkem512(NamedGroup.OQS_mlkem512, "ML-KEM"),
OQS_mlkem768(NamedGroup.OQS_mlkem768, "ML-KEM"),
OQS_mlkem1024(NamedGroup.OQS_mlkem1024, "ML-KEM"),
OQS_secp256Mlkem512(NamedGroup.OQS_secp256Mlkem512, "ML-KEM"),
OQS_secp384Mlkem768(NamedGroup.OQS_secp384Mlkem768, "ML-KEM"),
OQS_secp521Mlkem1024(NamedGroup.OQS_secp521Mlkem1024, "ML-KEM"),

DRAFT_mlkem768(NamedGroup.DRAFT_mlkem768, "ML-KEM"),
DRAFT_mlkem1024(NamedGroup.DRAFT_mlkem1024, "ML-KEM");

Expand Down
23 changes: 22 additions & 1 deletion tls/src/main/java/org/bouncycastle/tls/NamedGroup.java
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ public class NamedGroup
public static final int OQS_mlkem768 = 0x0248;
/** Experimental API (unstable): unofficial value from Open Quantum Safe project. */
public static final int OQS_mlkem1024 = 0x0249;
/** Experimental API (unstable): unofficial value from Open Quantum Safe project. */
public static final int OQS_secp256Mlkem512 = 0x2F47;
/** Experimental API (unstable): unofficial value from Open Quantum Safe project. */
public static final int OQS_secp384Mlkem768 = 0x2F48;
/** Experimental API (unstable): unofficial value from Open Quantum Safe project. */
public static final int OQS_secp521Mlkem1024 = 0x2F49;

/*
* draft-connolly-tls-mlkem-key-agreement-01
Expand Down Expand Up @@ -310,6 +316,12 @@ public static String getKemName(int namedGroup)
case OQS_mlkem1024:
case DRAFT_mlkem1024:
return "ML-KEM-1024";
case OQS_secp256Mlkem512:
return "secp256-ML-KEM-512";
case OQS_secp384Mlkem768:
return "secp384-ML-KEM-768";
case OQS_secp521Mlkem1024:
return "secp521-ML-KEM-1024";
default:
return null;
}
Expand Down Expand Up @@ -376,7 +388,13 @@ public static String getName(int namedGroup)
return "OQS_mlkem768";
case OQS_mlkem1024:
return "OQS_mlkem1024";
case DRAFT_mlkem768:
case OQS_secp256Mlkem512:
return "OQS_secp256Mlkem512";
case OQS_secp384Mlkem768:
return "OQS_secp384Mlkem768";
case OQS_secp521Mlkem1024:
return "OQS_secp521Mlkem1024";
case DRAFT_mlkem768:
return "DRAFT_mlkem768";
case DRAFT_mlkem1024:
return "DRAFT_mlkem1024";
Expand Down Expand Up @@ -497,6 +515,9 @@ public static boolean refersToASpecificKem(int namedGroup)
case OQS_mlkem512:
case OQS_mlkem768:
case OQS_mlkem1024:
case OQS_secp256Mlkem512:
case OQS_secp384Mlkem768:
case OQS_secp521Mlkem1024:
case DRAFT_mlkem768:
case DRAFT_mlkem1024:
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,15 @@ public TlsECDomain createECDomain(TlsECConfig ecConfig)

public TlsKemDomain createKemDomain(TlsKemConfig kemConfig)
{
return new BcTlsMLKemDomain(this, kemConfig);
switch (kemConfig.getNamedGroup())
{
case NamedGroup.OQS_secp256Mlkem512:
case NamedGroup.OQS_secp384Mlkem768:
case NamedGroup.OQS_secp521Mlkem1024:
return new BcTlsEcdhMlkemDomain(this, kemConfig);
default:
return new BcTlsMLKemDomain(this, kemConfig);
}
}

public TlsNonceGenerator createNonceGenerator(byte[] additionalSeedMaterial)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@
*/
public class BcTlsECDomain implements TlsECDomain
{
public int getPublicKeyByteLength()
{
return (((domainParameters.getCurve().getFieldSize() + 7) / 8) * 2) + 1;
}

public byte[] calculateECDHAgreementBytes(ECPrivateKeyParameters privateKey, ECPublicKeyParameters publicKey)
{
ECDHBasicAgreement basicAgreement = new ECDHBasicAgreement();
basicAgreement.init(privateKey);
BigInteger agreementValue = basicAgreement.calculateAgreement(publicKey);
return BigIntegers.asUnsignedByteArray(basicAgreement.getFieldSize(), agreementValue);
}

public static BcTlsSecret calculateECDHAgreement(BcTlsCrypto crypto, ECPrivateKeyParameters privateKey,
ECPublicKeyParameters publicKey)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package org.bouncycastle.tls.crypto.impl.bc;

import java.io.IOException;

import org.bouncycastle.crypto.AsymmetricCipherKeyPair;
import org.bouncycastle.crypto.SecretWithEncapsulation;
import org.bouncycastle.crypto.params.ECPrivateKeyParameters;
import org.bouncycastle.crypto.params.ECPublicKeyParameters;
import org.bouncycastle.pqc.crypto.crystals.kyber.KyberPrivateKeyParameters;
import org.bouncycastle.pqc.crypto.crystals.kyber.KyberPublicKeyParameters;
import org.bouncycastle.tls.crypto.TlsAgreement;
import org.bouncycastle.tls.crypto.TlsSecret;
import org.bouncycastle.util.Arrays;

public class BcTlsEcdhMlkem implements TlsAgreement
{
protected final BcTlsEcdhMlkemDomain domain;

protected AsymmetricCipherKeyPair ecLocalKeyPair;
protected ECPublicKeyParameters ecPeerPublicKey;
protected AsymmetricCipherKeyPair kyberLocalKeyPair;
protected KyberPublicKeyParameters kyberPeerPublicKey;
protected byte[] kyberCiphertext;
protected byte[] kyberSecret;
protected TlsSecret secret;

public BcTlsEcdhMlkem(BcTlsEcdhMlkemDomain domain)
{
this.domain = domain;
}

public byte[] generateEphemeral() throws IOException
{
this.ecLocalKeyPair = domain.getEcDomain().generateKeyPair();
byte[] ecPublickey = domain.getEcDomain().encodePublicKey((ECPublicKeyParameters)ecLocalKeyPair.getPublic());
if (domain.isServer())
{
return Arrays.concatenate(ecPublickey, kyberCiphertext);
}
else
{
this.kyberLocalKeyPair = domain.getMlkemDomain().generateKeyPair();
byte[] kyberPublicKey = domain.getMlkemDomain().encodePublicKey((KyberPublicKeyParameters)kyberLocalKeyPair.getPublic());
return Arrays.concatenate(ecPublickey, kyberPublicKey);
}
}

public void receivePeerValue(byte[] peerValue) throws IOException
{
this.ecPeerPublicKey = domain.getEcDomain().decodePublicKey(Arrays.copyOf(peerValue, domain.getEcDomain().getPublicKeyByteLength()));
byte[] kyberValue = Arrays.copyOfRange(peerValue, domain.getEcDomain().getPublicKeyByteLength(), peerValue.length);
if (domain.isServer())
{
this.kyberPeerPublicKey = domain.getMlkemDomain().decodePublicKey(kyberValue);
SecretWithEncapsulation encap = domain.getMlkemDomain().encapsulate(kyberPeerPublicKey);
kyberCiphertext = encap.getEncapsulation();
kyberSecret = encap.getSecret();
}
else
{
this.kyberCiphertext = Arrays.clone(kyberValue);
}
}

public TlsSecret calculateSecret() throws IOException
{
byte[] ecSecret = domain.getEcDomain().calculateECDHAgreementBytes((ECPrivateKeyParameters)ecLocalKeyPair.getPrivate(), ecPeerPublicKey);
if (!domain.isServer())
{
kyberSecret = domain.getMlkemDomain().decapsulate((KyberPrivateKeyParameters) kyberLocalKeyPair.getPrivate(), kyberCiphertext);
}
return domain.adoptLocalSecret(Arrays.concatenate(ecSecret, kyberSecret));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package org.bouncycastle.tls.crypto.impl.bc;

import org.bouncycastle.tls.NamedGroup;
import org.bouncycastle.tls.crypto.TlsAgreement;
import org.bouncycastle.tls.crypto.TlsECConfig;
import org.bouncycastle.tls.crypto.TlsKemConfig;
import org.bouncycastle.tls.crypto.TlsKemDomain;

public class BcTlsEcdhMlkemDomain implements TlsKemDomain
{
protected final BcTlsCrypto crypto;
protected final boolean isServer;
private final BcTlsECDomain ecDomain;
private final BcTlsMLKemDomain mlkemDomain;

public BcTlsEcdhMlkemDomain(BcTlsCrypto crypto, TlsKemConfig kemConfig)
{
this.crypto = crypto;
this.ecDomain = getBcTlsECDomain(crypto, kemConfig);
this.mlkemDomain = new BcTlsMLKemDomain(crypto, kemConfig);
this.isServer = kemConfig.isServer();
}

public BcTlsSecret adoptLocalSecret(byte[] secret)
{
return crypto.adoptLocalSecret(secret);
}

public TlsAgreement createKem()
{
return new BcTlsEcdhMlkem(this);
}

public boolean isServer()
{
return isServer;
}

public BcTlsECDomain getEcDomain()
{
return ecDomain;
}

public BcTlsMLKemDomain getMlkemDomain()
{
return mlkemDomain;
}

private BcTlsECDomain getBcTlsECDomain(BcTlsCrypto crypto, TlsKemConfig kemConfig)
{
switch (kemConfig.getNamedGroup())
{
case NamedGroup.OQS_secp256Mlkem512:
return new BcTlsECDomain(crypto, new TlsECConfig(NamedGroup.secp256r1));
case NamedGroup.OQS_secp384Mlkem768:
return new BcTlsECDomain(crypto, new TlsECConfig(NamedGroup.secp384r1));
case NamedGroup.OQS_secp521Mlkem1024:
return new BcTlsECDomain(crypto, new TlsECConfig(NamedGroup.secp521r1));
default:
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public void receivePeerValue(byte[] peerValue) throws IOException
}
else
{
this.secret = domain.decapsulate(privateKey, peerValue);
this.secret = domain.adoptLocalSecret(domain.decapsulate(privateKey, peerValue));
this.privateKey = null;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@ protected static KyberParameters getKyberParameters(int namedGroup)
switch (namedGroup)
{
case NamedGroup.OQS_mlkem512:
case NamedGroup.OQS_secp256Mlkem512:
return KyberParameters.kyber512;
case NamedGroup.OQS_mlkem768:
case NamedGroup.OQS_secp384Mlkem768:
case NamedGroup.DRAFT_mlkem768:
return KyberParameters.kyber768;
case NamedGroup.OQS_mlkem1024:
case NamedGroup.OQS_secp521Mlkem1024:
case NamedGroup.DRAFT_mlkem1024:
return KyberParameters.kyber1024;
default:
Expand Down Expand Up @@ -54,11 +57,10 @@ public TlsAgreement createKem()
return new BcTlsMLKem(this);
}

public BcTlsSecret decapsulate(KyberPrivateKeyParameters privateKey, byte[] ciphertext)
public byte[] decapsulate(KyberPrivateKeyParameters privateKey, byte[] ciphertext)
{
KyberKEMExtractor kemExtract = new KyberKEMExtractor(privateKey);
byte[] secret = kemExtract.extractSecret(ciphertext);
return adoptLocalSecret(secret);
return kemExtract.extractSecret(ciphertext);
}

public KyberPublicKeyParameters decodePublicKey(byte[] encoding)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,12 @@ else if (NamedGroup.refersToASpecificKem(namedGroup))
{
switch (namedGroup)
{
case NamedGroup.OQS_secp256Mlkem512:
return ECUtil.getAlgorithmParameters(this, NamedGroup.getCurveName(NamedGroup.secp256r1));
case NamedGroup.OQS_secp384Mlkem768:
return ECUtil.getAlgorithmParameters(this, NamedGroup.getCurveName(NamedGroup.secp384r1));
case NamedGroup.OQS_secp521Mlkem1024:
return ECUtil.getAlgorithmParameters(this, NamedGroup.getCurveName(NamedGroup.secp521r1));
/*
* TODO[tls-kem] Return AlgorithmParameters to check against disabled algorithms?
*/
Expand Down Expand Up @@ -848,7 +854,15 @@ public TlsECDomain createECDomain(TlsECConfig ecConfig)

public TlsKemDomain createKemDomain(TlsKemConfig kemConfig)
{
return new JceTlsMLKemDomain(this, kemConfig);
switch (kemConfig.getNamedGroup())
{
case NamedGroup.OQS_secp256Mlkem512:
case NamedGroup.OQS_secp384Mlkem768:
case NamedGroup.OQS_secp521Mlkem1024:
return new JceTlsEcdhMlkemDomain(this, kemConfig);
default:
return new JceTlsMLKemDomain(this, kemConfig);
}
}

public TlsSecret hkdfInit(int cryptoHashAlgorithm)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,23 @@ public JceTlsECDomain(JcaTlsCrypto crypto, TlsECConfig ecConfig)
throw new IllegalArgumentException("NamedGroup not supported: " + NamedGroup.getText(namedGroup));
}

public int getPublicKeyByteLength()
{
return (((ecCurve.getFieldSize() + 7) / 8) * 2) + 1;
}

public byte[] calculateECDHAgreementBytes(PrivateKey privateKey, PublicKey publicKey) throws IOException
{
try
{
return crypto.calculateKeyAgreement("ECDH", privateKey, publicKey, "TlsPremasterSecret");
}
catch (GeneralSecurityException e)
{
throw new TlsCryptoException("cannot calculate secret", e);
}
}

public JceTlsSecret calculateECDHAgreement(PrivateKey privateKey, PublicKey publicKey)
throws IOException
{
Expand Down
Loading