diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi index 74024d501454..1504f458ca32 100644 --- a/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi +++ b/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi @@ -15,6 +15,7 @@ from cryptography.hazmat.bindings._rust.openssl import ( ed25519, hashes, hmac, + hpke, kdf, keys, poly1305, @@ -34,6 +35,7 @@ __all__ = [ "ed25519", "hashes", "hmac", + "hpke", "kdf", "keys", "openssl_version", diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/hpke.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/hpke.pyi new file mode 100644 index 000000000000..25904206b3c1 --- /dev/null +++ b/src/cryptography/hazmat/bindings/_rust/openssl/hpke.pyi @@ -0,0 +1,32 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from cryptography.hazmat.primitives.asymmetric import x25519 +from cryptography.utils import Buffer + +class KEM: + X25519: KEM + +class KDF: + HKDF_SHA256: KDF + +class AEAD: + AES_128_GCM: AEAD + +class Suite: + def __init__(self, kem: KEM, kdf: KDF, aead: AEAD) -> None: ... + def encrypt( + self, + plaintext: Buffer, + public_key: x25519.X25519PublicKey, + info: Buffer | None = None, + aad: Buffer | None = None, + ) -> bytes: ... + def decrypt( + self, + ciphertext: Buffer, + private_key: x25519.X25519PrivateKey, + info: Buffer | None = None, + aad: Buffer | None = None, + ) -> bytes: ... diff --git a/src/cryptography/hazmat/primitives/hpke.py b/src/cryptography/hazmat/primitives/hpke.py index 5c3fec73479b..e6025159e565 100644 --- a/src/cryptography/hazmat/primitives/hpke.py +++ b/src/cryptography/hazmat/primitives/hpke.py @@ -4,246 +4,12 @@ from __future__ import annotations -import dataclasses -import enum - -from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.primitives.asymmetric import x25519 -from cryptography.hazmat.primitives.ciphers.aead import AESGCM -from cryptography.hazmat.primitives.kdf.hkdf import HKDF, HKDFExpand -from cryptography.utils import int_to_bytes - -_HPKE_VERSION = b"HPKE-v1" -_HPKE_MODE_BASE = 0x00 - - -class KEM(enum.Enum): - X25519 = "X25519" - - -class KDF(enum.Enum): - HKDF_SHA256 = "HKDF_SHA256" - - -class AEAD(enum.Enum): - AES_128_GCM = "AES_128_GCM" - - -@dataclasses.dataclass(frozen=True) -class _KEMParams: - id: int - nsecret: int - nenc: int - npk: int - nsk: int - hash: hashes.HashAlgorithm - - -@dataclasses.dataclass(frozen=True) -class _KDFParams: - id: int - nh: int - hash: hashes.HashAlgorithm - - -@dataclasses.dataclass(frozen=True) -class _AEADParams: - id: int - nk: int - nn: int - nt: int - - -def _get_kem_params(kem: KEM) -> _KEMParams: - assert kem == KEM.X25519 - return _KEMParams( - id=0x0020, - nsecret=32, - nenc=32, - npk=32, - nsk=32, - hash=hashes.SHA256(), - ) - - -def _get_kdf_params(kdf: KDF) -> _KDFParams: - assert kdf == KDF.HKDF_SHA256 - return _KDFParams( - id=0x0001, - nh=32, - hash=hashes.SHA256(), - ) - - -def _get_aead_params(aead: AEAD) -> _AEADParams: - assert aead == AEAD.AES_128_GCM - return _AEADParams( - id=0x0001, - nk=16, - nn=12, - nt=16, - ) - - -class Suite: - def __init__(self, kem: KEM, kdf: KDF, aead: AEAD) -> None: - if not isinstance(kem, KEM): - raise TypeError("kem must be an instance of KEM") - if not isinstance(kdf, KDF): - raise TypeError("kdf must be an instance of KDF") - if not isinstance(aead, AEAD): - raise TypeError("aead must be an instance of AEAD") - - self._kem = kem - self._kdf = kdf - self._aead = aead - - self._kem_params = _get_kem_params(kem) - self._kdf_params = _get_kdf_params(kdf) - self._aead_params = _get_aead_params(aead) - - # Build suite IDs - self._kem_suite_id = b"KEM" + int_to_bytes(self._kem_params.id, 2) - self._hpke_suite_id = ( - b"HPKE" - + int_to_bytes(self._kem_params.id, 2) - + int_to_bytes(self._kdf_params.id, 2) - + int_to_bytes(self._aead_params.id, 2) - ) - - def _kem_labeled_extract( - self, salt: bytes, label: bytes, ikm: bytes - ) -> bytes: - labeled_ikm = _HPKE_VERSION + self._kem_suite_id + label + ikm - return HKDF.extract( - self._kdf_params.hash, - salt if salt else None, - labeled_ikm, - ) - - def _kem_labeled_expand( - self, prk: bytes, label: bytes, info: bytes, length: int - ) -> bytes: - labeled_info = ( - int_to_bytes(length, 2) - + _HPKE_VERSION - + self._kem_suite_id - + label - + info - ) - hkdf_expand = HKDFExpand( - algorithm=self._kdf_params.hash, - length=length, - info=labeled_info, - ) - return hkdf_expand.derive(prk) - - def _extract_and_expand(self, dh: bytes, kem_context: bytes) -> bytes: - eae_prk = self._kem_labeled_extract(b"", b"eae_prk", dh) - shared_secret = self._kem_labeled_expand( - eae_prk, - b"shared_secret", - kem_context, - self._kem_params.nsecret, - ) - return shared_secret - - def _encap(self, pk_r: x25519.X25519PublicKey) -> tuple[bytes, bytes]: - sk_e = x25519.X25519PrivateKey.generate() - pk_e = sk_e.public_key() - dh = sk_e.exchange(pk_r) - enc = pk_e.public_bytes_raw() - pk_rm = pk_r.public_bytes_raw() - kem_context = enc + pk_rm - shared_secret = self._extract_and_expand(dh, kem_context) - return shared_secret, enc - - def _decap(self, enc: bytes, sk_r: x25519.X25519PrivateKey) -> bytes: - pk_e = x25519.X25519PublicKey.from_public_bytes(enc) - dh = sk_r.exchange(pk_e) - pk_rm = sk_r.public_key().public_bytes_raw() - kem_context = enc + pk_rm - shared_secret = self._extract_and_expand(dh, kem_context) - return shared_secret - - def _hpke_labeled_extract( - self, salt: bytes, label: bytes, ikm: bytes - ) -> bytes: - labeled_ikm = _HPKE_VERSION + self._hpke_suite_id + label + ikm - return HKDF.extract( - self._kdf_params.hash, - salt if salt else None, - labeled_ikm, - ) - - def _hpke_labeled_expand( - self, prk: bytes, label: bytes, info: bytes, length: int - ) -> bytes: - labeled_info = ( - int_to_bytes(length, 2) - + _HPKE_VERSION - + self._hpke_suite_id - + label - + info - ) - hkdf_expand = HKDFExpand( - algorithm=self._kdf_params.hash, - length=length, - info=labeled_info, - ) - return hkdf_expand.derive(prk) - - def _key_schedule( - self, shared_secret: bytes, info: bytes - ) -> tuple[bytes, bytes]: - mode = _HPKE_MODE_BASE - - psk_id_hash = self._hpke_labeled_extract(b"", b"psk_id_hash", b"") - info_hash = self._hpke_labeled_extract(b"", b"info_hash", info) - key_schedule_context = bytes([mode]) + psk_id_hash + info_hash - - secret = self._hpke_labeled_extract(shared_secret, b"secret", b"") - - key = self._hpke_labeled_expand( - secret, b"key", key_schedule_context, self._aead_params.nk - ) - base_nonce = self._hpke_labeled_expand( - secret, - b"base_nonce", - key_schedule_context, - self._aead_params.nn, - ) - - return key, base_nonce - - def encrypt( - self, - plaintext: bytes, - public_key: x25519.X25519PublicKey, - info: bytes = b"", - aad: bytes = b"", - ) -> bytes: - shared_secret, enc = self._encap(public_key) - key, base_nonce = self._key_schedule(shared_secret, info) - aead_impl = AESGCM(key) - ct = aead_impl.encrypt(base_nonce, plaintext, aad) - return enc + ct - - def decrypt( - self, - ciphertext: bytes, - private_key: x25519.X25519PrivateKey, - info: bytes = b"", - aad: bytes = b"", - ) -> bytes: - nenc = self._kem_params.nenc - enc = ciphertext[:nenc] - ct = ciphertext[nenc:] - shared_secret = self._decap(enc, private_key) - key, base_nonce = self._key_schedule(shared_secret, info) - aead_impl = AESGCM(key) - return aead_impl.decrypt(base_nonce, ct, aad) +from cryptography.hazmat.bindings._rust import openssl as rust_openssl +AEAD = rust_openssl.hpke.AEAD +KDF = rust_openssl.hpke.KDF +KEM = rust_openssl.hpke.KEM +Suite = rust_openssl.hpke.Suite __all__ = [ "AEAD", diff --git a/src/rust/src/backend/aead.rs b/src/rust/src/backend/aead.rs index f94da711b8fe..a5fc3478d651 100644 --- a/src/rust/src/backend/aead.rs +++ b/src/rust/src/backend/aead.rs @@ -21,12 +21,12 @@ fn check_length(data: &[u8]) -> CryptographyResult<()> { Ok(()) } -enum Aad<'a> { +pub(crate) enum Aad<'a> { Single(CffiBuf<'a>), List(pyo3::Bound<'a, pyo3::types::PyList>), } -struct EvpCipherAead { +pub(crate) struct EvpCipherAead { base_encryption_ctx: openssl::cipher_ctx::CipherCtx, base_decryption_ctx: openssl::cipher_ctx::CipherCtx, tag_len: usize, @@ -34,7 +34,7 @@ struct EvpCipherAead { } impl EvpCipherAead { - fn new( + pub(crate) fn new( cipher: &openssl::cipher::CipherRef, key: &[u8], tag_len: usize, @@ -127,7 +127,7 @@ impl EvpCipherAead { Ok(()) } - fn encrypt_into( + pub(crate) fn encrypt_into( &self, // We have this arg so we have consistent arguments with encrypt_into in // LazyEvpCipherAead. We can remove it when we remove LazyEvpCipherAead. @@ -192,7 +192,7 @@ impl EvpCipherAead { Ok(()) } - fn decrypt_into( + pub(crate) fn decrypt_into( &self, // We have this arg so we have consistent arguments with decrypt_into in // LazyEvpCipherAead. We can remove it when we remove LazyEvpCipherAead. @@ -626,7 +626,7 @@ impl ChaCha20Poly1305 { name = "AESGCM" )] // NO-COVERAGE-END -struct AesGcm { +pub(crate) struct AesGcm { #[cfg(any( CRYPTOGRAPHY_OPENSSL_320_OR_GREATER, CRYPTOGRAPHY_IS_LIBRESSL, @@ -647,7 +647,10 @@ struct AesGcm { #[pyo3::pymethods] impl AesGcm { #[new] - fn new(py: pyo3::Python<'_>, key: pyo3::Py) -> CryptographyResult { + pub(crate) fn new( + py: pyo3::Python<'_>, + key: pyo3::Py, + ) -> CryptographyResult { let key_buf = key.extract::>(py)?; let cipher = match key_buf.as_bytes().len() { 16 => openssl::cipher::Cipher::aes_128_gcm(), @@ -696,7 +699,7 @@ impl AesGcm { } #[pyo3(signature = (nonce, data, associated_data))] - fn encrypt<'p>( + pub(crate) fn encrypt<'p>( &self, py: pyo3::Python<'p>, nonce: CffiBuf<'_>, @@ -754,7 +757,7 @@ impl AesGcm { } #[pyo3(signature = (nonce, data, associated_data))] - fn decrypt<'p>( + pub(crate) fn decrypt<'p>( &self, py: pyo3::Python<'p>, nonce: CffiBuf<'_>, diff --git a/src/rust/src/backend/hpke.rs b/src/rust/src/backend/hpke.rs new file mode 100644 index 000000000000..7bed017d4591 --- /dev/null +++ b/src/rust/src/backend/hpke.rs @@ -0,0 +1,361 @@ +// This file is dual licensed under the terms of the Apache License, Version +// 2.0, and the BSD License. See the LICENSE file in the root of this repository +// for complete details. + +use crate::backend::aead::AesGcm; +use crate::backend::kdf::{hkdf_extract, HkdfExpand}; +use crate::backend::x25519; +use crate::buf::CffiBuf; +use crate::error::{CryptographyError, CryptographyResult}; +use crate::exceptions; +use crate::types; +use pyo3::types::{PyAnyMethods, PyBytesMethods}; + +const HPKE_VERSION: &[u8] = b"HPKE-v1"; +const HPKE_MODE_BASE: u8 = 0x00; + +mod kem_params { + pub const X25519_ID: u16 = 0x0020; + pub const X25519_NSECRET: usize = 32; + pub const X25519_NENC: usize = 32; +} + +mod kdf_params { + pub const HKDF_SHA256_ID: u16 = 0x0001; +} + +mod aead_params { + pub const AES_128_GCM_ID: u16 = 0x0001; + pub const AES_128_GCM_NK: usize = 16; + pub const AES_128_GCM_NN: usize = 12; +} + +#[allow(clippy::upper_case_acronyms)] +#[pyo3::pyclass( + frozen, + eq, + hash, + module = "cryptography.hazmat.bindings._rust.openssl.hpke" +)] +#[derive(Clone, PartialEq, Eq, Hash)] +pub(crate) enum KEM { + X25519, +} + +#[allow(clippy::upper_case_acronyms)] +#[allow(non_camel_case_types)] +#[pyo3::pyclass( + frozen, + eq, + hash, + module = "cryptography.hazmat.bindings._rust.openssl.hpke" +)] +#[derive(Clone, PartialEq, Eq, Hash)] +pub(crate) enum KDF { + HKDF_SHA256, +} + +#[allow(clippy::upper_case_acronyms)] +#[allow(non_camel_case_types)] +#[pyo3::pyclass( + frozen, + eq, + hash, + module = "cryptography.hazmat.bindings._rust.openssl.hpke" +)] +#[derive(Clone, PartialEq, Eq, Hash)] +pub(crate) enum AEAD { + AES_128_GCM, +} + +#[pyo3::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.openssl.hpke")] +pub(crate) struct Suite { + kem_suite_id: [u8; 5], + hpke_suite_id: [u8; 10], +} + +impl Suite { + fn hkdf_expand<'p>( + &self, + py: pyo3::Python<'p>, + prk: &[u8], + info: &[u8], + length: usize, + ) -> CryptographyResult> { + let algorithm = types::SHA256.get(py)?.call0()?; + + let mut hkdf_expand = HkdfExpand::new( + py, + algorithm.unbind(), + length, + Some(pyo3::types::PyBytes::new(py, info).unbind()), + None, + )?; + hkdf_expand.derive(py, CffiBuf::from_bytes(py, prk)) + } + + fn kem_labeled_extract( + &self, + py: pyo3::Python<'_>, + label: &[u8], + ikm: &[u8], + ) -> CryptographyResult { + let mut labeled_ikm = Vec::with_capacity(HPKE_VERSION.len() + 5 + label.len() + ikm.len()); + labeled_ikm.extend_from_slice(HPKE_VERSION); + labeled_ikm.extend_from_slice(&self.kem_suite_id); + labeled_ikm.extend_from_slice(label); + labeled_ikm.extend_from_slice(ikm); + + let algorithm = types::SHA256.get(py)?.call0()?; + let buf = CffiBuf::from_bytes(py, &labeled_ikm); + hkdf_extract(py, &algorithm.unbind(), None, &buf) + } + + fn kem_labeled_expand<'p>( + &self, + py: pyo3::Python<'p>, + prk: &[u8], + label: &[u8], + info: &[u8], + length: usize, + ) -> CryptographyResult> { + let mut labeled_info = + Vec::with_capacity(2 + HPKE_VERSION.len() + 5 + label.len() + info.len()); + labeled_info.extend_from_slice(&(length as u16).to_be_bytes()); + labeled_info.extend_from_slice(HPKE_VERSION); + labeled_info.extend_from_slice(&self.kem_suite_id); + labeled_info.extend_from_slice(label); + labeled_info.extend_from_slice(info); + self.hkdf_expand(py, prk, &labeled_info, length) + } + + fn extract_and_expand<'p>( + &self, + py: pyo3::Python<'p>, + dh: &[u8], + kem_context: &[u8], + ) -> CryptographyResult> { + let eae_prk = self.kem_labeled_extract(py, b"eae_prk", dh)?; + self.kem_labeled_expand( + py, + &eae_prk, + b"shared_secret", + kem_context, + kem_params::X25519_NSECRET, + ) + } + + fn encap<'p>( + &self, + py: pyo3::Python<'p>, + pk_r: &pyo3::Bound<'_, pyo3::PyAny>, + ) -> CryptographyResult<( + pyo3::Bound<'p, pyo3::types::PyBytes>, + pyo3::Bound<'p, pyo3::types::PyBytes>, + )> { + let sk_e = pyo3::Bound::new(py, x25519::generate_key()?)?; + let pk_e = sk_e.call_method0(pyo3::intern!(py, "public_key"))?; + + let pk_e_bytes: pyo3::Bound<'p, pyo3::types::PyBytes> = pk_e + .call_method0(pyo3::intern!(py, "public_bytes_raw"))? + .extract()?; + + let pk_r_bytes = pk_r.call_method0(pyo3::intern!(py, "public_bytes_raw"))?; + let pk_r_raw = pk_r_bytes.extract::<&[u8]>()?; + + let dh_result = sk_e.call_method1(pyo3::intern!(py, "exchange"), (pk_r,))?; + let dh = dh_result.extract::<&[u8]>()?; + + let mut kem_context = [0u8; 64]; + kem_context[..32].copy_from_slice(pk_e_bytes.as_bytes()); + kem_context[32..].copy_from_slice(pk_r_raw); + let shared_secret = self.extract_and_expand(py, dh, &kem_context)?; + Ok((shared_secret, pk_e_bytes)) + } + + fn decap<'p>( + &self, + py: pyo3::Python<'p>, + enc: &[u8], + sk_r: &pyo3::Bound<'_, pyo3::PyAny>, + ) -> CryptographyResult> { + // Reconstruct pk_e from enc + let pk_e = pyo3::Bound::new(py, x25519::from_public_bytes(enc)?)?; + + let dh_result = sk_r.call_method1(pyo3::intern!(py, "exchange"), (&pk_e,))?; + let dh = dh_result.extract::<&[u8]>()?; + + let pk_rm = sk_r.call_method0(pyo3::intern!(py, "public_key"))?; + let pk_rm_bytes = pk_rm.call_method0(pyo3::intern!(py, "public_bytes_raw"))?; + let pk_rm_raw = pk_rm_bytes.extract::<&[u8]>()?; + + let mut kem_context = [0u8; 64]; + kem_context[..32].copy_from_slice(enc); + kem_context[32..].copy_from_slice(pk_rm_raw); + self.extract_and_expand(py, dh, &kem_context) + } + + fn hpke_labeled_extract( + &self, + py: pyo3::Python<'_>, + salt: Option<&[u8]>, + label: &[u8], + ikm: &[u8], + ) -> CryptographyResult { + let mut labeled_ikm = Vec::with_capacity(HPKE_VERSION.len() + 10 + label.len() + ikm.len()); + labeled_ikm.extend_from_slice(HPKE_VERSION); + labeled_ikm.extend_from_slice(&self.hpke_suite_id); + labeled_ikm.extend_from_slice(label); + labeled_ikm.extend_from_slice(ikm); + + let algorithm = types::SHA256.get(py)?.call0()?; + let buf = CffiBuf::from_bytes(py, &labeled_ikm); + hkdf_extract(py, &algorithm.unbind(), salt, &buf) + } + + fn hpke_labeled_expand<'p>( + &self, + py: pyo3::Python<'p>, + prk: &[u8], + label: &[u8], + info: &[u8], + length: usize, + ) -> CryptographyResult> { + let mut labeled_info = + Vec::with_capacity(2 + HPKE_VERSION.len() + 10 + label.len() + info.len()); + labeled_info.extend_from_slice(&(length as u16).to_be_bytes()); + labeled_info.extend_from_slice(HPKE_VERSION); + labeled_info.extend_from_slice(&self.hpke_suite_id); + labeled_info.extend_from_slice(label); + labeled_info.extend_from_slice(info); + self.hkdf_expand(py, prk, &labeled_info, length) + } + + fn key_schedule( + &self, + py: pyo3::Python<'_>, + shared_secret: &[u8], + info: &[u8], + ) -> CryptographyResult<( + [u8; aead_params::AES_128_GCM_NK], + [u8; aead_params::AES_128_GCM_NN], + )> { + let psk_id_hash = self.hpke_labeled_extract(py, None, b"psk_id_hash", b"")?; + let info_hash = self.hpke_labeled_extract(py, None, b"info_hash", info)?; + let mut key_schedule_context = vec![HPKE_MODE_BASE]; + key_schedule_context.extend_from_slice(&psk_id_hash); + key_schedule_context.extend_from_slice(&info_hash); + + let secret = self.hpke_labeled_extract(py, Some(shared_secret), b"secret", b"")?; + + let key_vec = self.hpke_labeled_expand( + py, + &secret, + b"key", + &key_schedule_context, + aead_params::AES_128_GCM_NK, + )?; + let nonce_vec = self.hpke_labeled_expand( + py, + &secret, + b"base_nonce", + &key_schedule_context, + aead_params::AES_128_GCM_NN, + )?; + + let mut key = [0u8; aead_params::AES_128_GCM_NK]; + let mut base_nonce = [0u8; aead_params::AES_128_GCM_NN]; + key.copy_from_slice(key_vec.as_bytes()); + base_nonce.copy_from_slice(nonce_vec.as_bytes()); + + Ok((key, base_nonce)) + } +} + +#[pyo3::pymethods] +impl Suite { + #[new] + fn new(_kem: KEM, _kdf: KDF, _aead: AEAD) -> CryptographyResult { + // Build suite IDs + let mut kem_suite_id = [0u8; 5]; + kem_suite_id[..3].copy_from_slice(b"KEM"); + kem_suite_id[3..].copy_from_slice(&kem_params::X25519_ID.to_be_bytes()); + + let mut hpke_suite_id = [0u8; 10]; + hpke_suite_id[..4].copy_from_slice(b"HPKE"); + hpke_suite_id[4..6].copy_from_slice(&kem_params::X25519_ID.to_be_bytes()); + hpke_suite_id[6..8].copy_from_slice(&kdf_params::HKDF_SHA256_ID.to_be_bytes()); + hpke_suite_id[8..10].copy_from_slice(&aead_params::AES_128_GCM_ID.to_be_bytes()); + + Ok(Suite { + kem_suite_id, + hpke_suite_id, + }) + } + + #[pyo3(signature = (plaintext, public_key, info=None, aad=None))] + fn encrypt<'p>( + &self, + py: pyo3::Python<'p>, + plaintext: CffiBuf<'_>, + public_key: &pyo3::Bound<'_, pyo3::PyAny>, + info: Option>, + aad: Option>, + ) -> CryptographyResult> { + let info_bytes: &[u8] = info.as_ref().map(|b| b.as_bytes()).unwrap_or(b""); + + let (shared_secret, enc) = self.encap(py, public_key)?; + let (key, base_nonce) = self.key_schedule(py, shared_secret.as_bytes(), info_bytes)?; + + let aesgcm = AesGcm::new(py, pyo3::types::PyBytes::new(py, &key).unbind().into_any())?; + let ct = aesgcm.encrypt(py, CffiBuf::from_bytes(py, &base_nonce), plaintext, aad)?; + + let enc_bytes = enc.as_bytes(); + let ct_bytes = ct.as_bytes(); + Ok(pyo3::types::PyBytes::new_with( + py, + enc_bytes.len() + ct_bytes.len(), + |buf| { + buf[..enc_bytes.len()].copy_from_slice(enc_bytes); + buf[enc_bytes.len()..].copy_from_slice(ct_bytes); + Ok(()) + }, + )?) + } + + #[pyo3(signature = (ciphertext, private_key, info=None, aad=None))] + fn decrypt<'p>( + &self, + py: pyo3::Python<'p>, + ciphertext: CffiBuf<'_>, + private_key: &pyo3::Bound<'_, pyo3::PyAny>, + info: Option>, + aad: Option>, + ) -> CryptographyResult> { + let ct_bytes = ciphertext.as_bytes(); + if ct_bytes.len() < kem_params::X25519_NENC { + return Err(CryptographyError::from(exceptions::InvalidTag::new_err(()))); + } + + let info_bytes: &[u8] = info.as_ref().map(|b| b.as_bytes()).unwrap_or(b""); + + let (enc, ct) = ct_bytes.split_at(kem_params::X25519_NENC); + + let shared_secret = self.decap(py, enc, private_key)?; + let (key, base_nonce) = self.key_schedule(py, shared_secret.as_bytes(), info_bytes)?; + + let aesgcm = AesGcm::new(py, pyo3::types::PyBytes::new(py, &key).unbind().into_any())?; + aesgcm.decrypt( + py, + CffiBuf::from_bytes(py, &base_nonce), + CffiBuf::from_bytes(py, ct), + aad, + ) + } +} + +#[pyo3::pymodule(gil_used = false)] +pub(crate) mod hpke { + #[pymodule_export] + use super::{Suite, AEAD, KDF, KEM}; +} diff --git a/src/rust/src/backend/kdf.rs b/src/rust/src/backend/kdf.rs index 3c1838097a9d..40aaf0834521 100644 --- a/src/rust/src/backend/kdf.rs +++ b/src/rust/src/backend/kdf.rs @@ -954,23 +954,18 @@ struct Hkdf { used: bool, } -fn hkdf_extract( +pub(crate) fn hkdf_extract( py: pyo3::Python<'_>, algorithm: &pyo3::Py, - salt: Option<&pyo3::Py>, + salt: Option<&[u8]>, key_material: &CffiBuf<'_>, ) -> CryptographyResult { let algorithm_bound = algorithm.bind(py); let digest_size = algorithm_bound .getattr(pyo3::intern!(py, "digest_size"))? .extract::()?; - let salt_bound = salt.map(|s| s.bind(py)); let default_salt = vec![0; digest_size]; - let salt_bytes: &[u8] = if let Some(bound) = salt_bound { - bound.as_bytes() - } else { - &default_salt - }; + let salt_bytes = salt.unwrap_or(&default_salt); let mut hmac = Hmac::new_bytes(py, salt_bytes, algorithm_bound)?; hmac.update_bytes(key_material.as_bytes())?; @@ -999,7 +994,12 @@ impl Hkdf { } let buf = CffiBuf::from_bytes(py, key_material); - let prk = hkdf_extract(py, &self.algorithm, self.salt.as_ref(), &buf)?; + let prk = hkdf_extract( + py, + &self.algorithm, + self.salt.as_ref().map(|s| s.bind(py)).map(|s| s.as_bytes()), + &buf, + )?; let mut hkdf_expand = HkdfExpand::new( py, self.algorithm.clone_ref(py), @@ -1056,10 +1056,10 @@ impl Hkdf { fn extract<'p>( py: pyo3::Python<'p>, algorithm: pyo3::Py, - salt: Option>, + salt: Option<&[u8]>, key_material: CffiBuf<'_>, ) -> CryptographyResult> { - let prk = hkdf_extract(py, &algorithm, salt.as_ref(), &key_material)?; + let prk = hkdf_extract(py, &algorithm, salt, &key_material)?; Ok(pyo3::types::PyBytes::new(py, &prk)) } @@ -1068,7 +1068,12 @@ impl Hkdf { py: pyo3::Python<'p>, key_material: CffiBuf<'_>, ) -> CryptographyResult> { - let prk = hkdf_extract(py, &self.algorithm, self.salt.as_ref(), &key_material)?; + let prk = hkdf_extract( + py, + &self.algorithm, + self.salt.as_ref().map(|s| s.bind(py)).map(|s| s.as_bytes()), + &key_material, + )?; Ok(pyo3::types::PyBytes::new(py, &prk)) } @@ -1118,7 +1123,7 @@ impl Hkdf { name = "HKDFExpand" )] // NO-COVERAGE-END -struct HkdfExpand { +pub(crate) struct HkdfExpand { algorithm: pyo3::Py, info: pyo3::Py, length: usize, @@ -1181,7 +1186,7 @@ impl HkdfExpand { impl HkdfExpand { #[new] #[pyo3(signature = (algorithm, length, info, backend=None))] - fn new( + pub(crate) fn new( py: pyo3::Python<'_>, algorithm: pyo3::Py, length: usize, @@ -1231,7 +1236,7 @@ impl HkdfExpand { self.derive_into_buffer(py, key_material.as_bytes(), buf.as_mut_bytes()) } - fn derive<'p>( + pub(crate) fn derive<'p>( &mut self, py: pyo3::Python<'p>, key_material: CffiBuf<'_>, diff --git a/src/rust/src/backend/mod.rs b/src/rust/src/backend/mod.rs index aceea4c166c9..a9133cafb8c8 100644 --- a/src/rust/src/backend/mod.rs +++ b/src/rust/src/backend/mod.rs @@ -18,6 +18,7 @@ pub(crate) mod ed25519; pub(crate) mod ed448; pub(crate) mod hashes; pub(crate) mod hmac; +pub(crate) mod hpke; pub(crate) mod kdf; pub(crate) mod keys; pub(crate) mod poly1305; diff --git a/src/rust/src/backend/x25519.rs b/src/rust/src/backend/x25519.rs index 9ee092725aec..fac7a9d46ffc 100644 --- a/src/rust/src/backend/x25519.rs +++ b/src/rust/src/backend/x25519.rs @@ -17,7 +17,7 @@ pub(crate) struct X25519PublicKey { } #[pyo3::pyfunction] -fn generate_key() -> CryptographyResult { +pub(crate) fn generate_key() -> CryptographyResult { Ok(X25519PrivateKey { pkey: openssl::pkey::PKey::generate_x25519()?, }) @@ -52,7 +52,7 @@ fn from_private_bytes(data: CffiBuf<'_>) -> pyo3::PyResult { } #[pyo3::pyfunction] -fn from_public_bytes(data: &[u8]) -> pyo3::PyResult { +pub(crate) fn from_public_bytes(data: &[u8]) -> pyo3::PyResult { let pkey = openssl::pkey::PKey::public_key_from_raw_bytes(data, openssl::pkey::Id::X25519) .map_err(|_| { pyo3::exceptions::PyValueError::new_err("An X25519 public key is 32 bytes long") diff --git a/src/rust/src/lib.rs b/src/rust/src/lib.rs index 32067299f467..dff16a94dc01 100644 --- a/src/rust/src/lib.rs +++ b/src/rust/src/lib.rs @@ -231,6 +231,8 @@ mod _rust { #[pymodule_export] use crate::backend::hmac::hmac; #[pymodule_export] + use crate::backend::hpke::hpke; + #[pymodule_export] use crate::backend::kdf::kdf; #[pymodule_export] use crate::backend::keys::keys; diff --git a/tests/hazmat/primitives/test_hpke.py b/tests/hazmat/primitives/test_hpke.py index f257aeb486da..39e6176f9833 100644 --- a/tests/hazmat/primitives/test_hpke.py +++ b/tests/hazmat/primitives/test_hpke.py @@ -31,17 +31,15 @@ ) class TestHPKE: def test_invalid_kem_type(self): - with pytest.raises(TypeError, match="kem must be an instance of KEM"): + with pytest.raises(TypeError): Suite("not a kem", KDF.HKDF_SHA256, AEAD.AES_128_GCM) # type: ignore[arg-type] def test_invalid_kdf_type(self): - with pytest.raises(TypeError, match="kdf must be an instance of KDF"): + with pytest.raises(TypeError): Suite(KEM.X25519, "not a kdf", AEAD.AES_128_GCM) # type: ignore[arg-type] def test_invalid_aead_type(self): - with pytest.raises( - TypeError, match="aead must be an instance of AEAD" - ): + with pytest.raises(TypeError): Suite(KEM.X25519, KDF.HKDF_SHA256, "not an aead") # type: ignore[arg-type] @pytest.mark.parametrize("kem,kdf,aead", SUPPORTED_SUITES) @@ -172,6 +170,8 @@ def test_truncated_ciphertext(self): with pytest.raises(InvalidTag): suite.decrypt(truncated, sk_r) + with pytest.raises(InvalidTag): + suite.decrypt(b"\x00", sk_r) def test_vector_decryption(self, subtests): vectors = load_vectors_from_file(