From 8b76d68a9c74b91c34e0762c8d8589bf9a6cded7 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 28 Jan 2026 23:33:10 +0000 Subject: [PATCH 1/7] Migrate HPKE implementation from Python to Rust Move the HPKE (Hybrid Public Key Encryption) implementation from pure Python to Rust for improved performance. The implementation follows RFC 9180 and includes: - KEM: DHKEM(X25519, HKDF-SHA256) - KDF: HKDF-SHA256 - AEAD: AES-128-GCM - Mode: Base mode only The public API remains unchanged - KEM, KDF, AEAD enums and Suite class are now backed by Rust via PyO3. https://claude.ai/code/session_01W43m9LudrvqkHKr4BrMa7c --- .../bindings/_rust/openssl/__init__.pyi | 2 + .../hazmat/bindings/_rust/openssl/hpke.pyi | 44 ++ src/cryptography/hazmat/primitives/hpke.py | 244 +------- src/rust/src/backend/hpke.rs | 555 ++++++++++++++++++ src/rust/src/backend/mod.rs | 1 + src/rust/src/backend/x25519.rs | 20 + src/rust/src/lib.rs | 2 + 7 files changed, 629 insertions(+), 239 deletions(-) create mode 100644 src/cryptography/hazmat/bindings/_rust/openssl/hpke.pyi create mode 100644 src/rust/src/backend/hpke.rs diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi index 74024d501454..1504f458ca32 100644 --- a/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi +++ b/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi @@ -15,6 +15,7 @@ from cryptography.hazmat.bindings._rust.openssl import ( ed25519, hashes, hmac, + hpke, kdf, keys, poly1305, @@ -34,6 +35,7 @@ __all__ = [ "ed25519", "hashes", "hmac", + "hpke", "kdf", "keys", "openssl_version", diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/hpke.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/hpke.pyi new file mode 100644 index 000000000000..609b6587da55 --- /dev/null +++ b/src/cryptography/hazmat/bindings/_rust/openssl/hpke.pyi @@ -0,0 +1,44 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from cryptography.hazmat.primitives.asymmetric import x25519 +from cryptography.utils import Buffer + +class KEM: + X25519: KEM + @property + def value(self) -> str: ... + def __eq__(self, other: object) -> bool: ... + def __hash__(self) -> int: ... + +class KDF: + HKDF_SHA256: KDF + @property + def value(self) -> str: ... + def __eq__(self, other: object) -> bool: ... + def __hash__(self) -> int: ... + +class AEAD: + AES_128_GCM: AEAD + @property + def value(self) -> str: ... + def __eq__(self, other: object) -> bool: ... + def __hash__(self) -> int: ... + +class Suite: + def __init__(self, kem: KEM, kdf: KDF, aead: AEAD) -> None: ... + def encrypt( + self, + plaintext: Buffer, + public_key: x25519.X25519PublicKey, + info: Buffer | None = None, + aad: Buffer | None = None, + ) -> bytes: ... + def decrypt( + self, + ciphertext: Buffer, + private_key: x25519.X25519PrivateKey, + info: Buffer | None = None, + aad: Buffer | None = None, + ) -> bytes: ... diff --git a/src/cryptography/hazmat/primitives/hpke.py b/src/cryptography/hazmat/primitives/hpke.py index 5c3fec73479b..e6025159e565 100644 --- a/src/cryptography/hazmat/primitives/hpke.py +++ b/src/cryptography/hazmat/primitives/hpke.py @@ -4,246 +4,12 @@ from __future__ import annotations -import dataclasses -import enum - -from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.primitives.asymmetric import x25519 -from cryptography.hazmat.primitives.ciphers.aead import AESGCM -from cryptography.hazmat.primitives.kdf.hkdf import HKDF, HKDFExpand -from cryptography.utils import int_to_bytes - -_HPKE_VERSION = b"HPKE-v1" -_HPKE_MODE_BASE = 0x00 - - -class KEM(enum.Enum): - X25519 = "X25519" - - -class KDF(enum.Enum): - HKDF_SHA256 = "HKDF_SHA256" - - -class AEAD(enum.Enum): - AES_128_GCM = "AES_128_GCM" - - -@dataclasses.dataclass(frozen=True) -class _KEMParams: - id: int - nsecret: int - nenc: int - npk: int - nsk: int - hash: hashes.HashAlgorithm - - -@dataclasses.dataclass(frozen=True) -class _KDFParams: - id: int - nh: int - hash: hashes.HashAlgorithm - - -@dataclasses.dataclass(frozen=True) -class _AEADParams: - id: int - nk: int - nn: int - nt: int - - -def _get_kem_params(kem: KEM) -> _KEMParams: - assert kem == KEM.X25519 - return _KEMParams( - id=0x0020, - nsecret=32, - nenc=32, - npk=32, - nsk=32, - hash=hashes.SHA256(), - ) - - -def _get_kdf_params(kdf: KDF) -> _KDFParams: - assert kdf == KDF.HKDF_SHA256 - return _KDFParams( - id=0x0001, - nh=32, - hash=hashes.SHA256(), - ) - - -def _get_aead_params(aead: AEAD) -> _AEADParams: - assert aead == AEAD.AES_128_GCM - return _AEADParams( - id=0x0001, - nk=16, - nn=12, - nt=16, - ) - - -class Suite: - def __init__(self, kem: KEM, kdf: KDF, aead: AEAD) -> None: - if not isinstance(kem, KEM): - raise TypeError("kem must be an instance of KEM") - if not isinstance(kdf, KDF): - raise TypeError("kdf must be an instance of KDF") - if not isinstance(aead, AEAD): - raise TypeError("aead must be an instance of AEAD") - - self._kem = kem - self._kdf = kdf - self._aead = aead - - self._kem_params = _get_kem_params(kem) - self._kdf_params = _get_kdf_params(kdf) - self._aead_params = _get_aead_params(aead) - - # Build suite IDs - self._kem_suite_id = b"KEM" + int_to_bytes(self._kem_params.id, 2) - self._hpke_suite_id = ( - b"HPKE" - + int_to_bytes(self._kem_params.id, 2) - + int_to_bytes(self._kdf_params.id, 2) - + int_to_bytes(self._aead_params.id, 2) - ) - - def _kem_labeled_extract( - self, salt: bytes, label: bytes, ikm: bytes - ) -> bytes: - labeled_ikm = _HPKE_VERSION + self._kem_suite_id + label + ikm - return HKDF.extract( - self._kdf_params.hash, - salt if salt else None, - labeled_ikm, - ) - - def _kem_labeled_expand( - self, prk: bytes, label: bytes, info: bytes, length: int - ) -> bytes: - labeled_info = ( - int_to_bytes(length, 2) - + _HPKE_VERSION - + self._kem_suite_id - + label - + info - ) - hkdf_expand = HKDFExpand( - algorithm=self._kdf_params.hash, - length=length, - info=labeled_info, - ) - return hkdf_expand.derive(prk) - - def _extract_and_expand(self, dh: bytes, kem_context: bytes) -> bytes: - eae_prk = self._kem_labeled_extract(b"", b"eae_prk", dh) - shared_secret = self._kem_labeled_expand( - eae_prk, - b"shared_secret", - kem_context, - self._kem_params.nsecret, - ) - return shared_secret - - def _encap(self, pk_r: x25519.X25519PublicKey) -> tuple[bytes, bytes]: - sk_e = x25519.X25519PrivateKey.generate() - pk_e = sk_e.public_key() - dh = sk_e.exchange(pk_r) - enc = pk_e.public_bytes_raw() - pk_rm = pk_r.public_bytes_raw() - kem_context = enc + pk_rm - shared_secret = self._extract_and_expand(dh, kem_context) - return shared_secret, enc - - def _decap(self, enc: bytes, sk_r: x25519.X25519PrivateKey) -> bytes: - pk_e = x25519.X25519PublicKey.from_public_bytes(enc) - dh = sk_r.exchange(pk_e) - pk_rm = sk_r.public_key().public_bytes_raw() - kem_context = enc + pk_rm - shared_secret = self._extract_and_expand(dh, kem_context) - return shared_secret - - def _hpke_labeled_extract( - self, salt: bytes, label: bytes, ikm: bytes - ) -> bytes: - labeled_ikm = _HPKE_VERSION + self._hpke_suite_id + label + ikm - return HKDF.extract( - self._kdf_params.hash, - salt if salt else None, - labeled_ikm, - ) - - def _hpke_labeled_expand( - self, prk: bytes, label: bytes, info: bytes, length: int - ) -> bytes: - labeled_info = ( - int_to_bytes(length, 2) - + _HPKE_VERSION - + self._hpke_suite_id - + label - + info - ) - hkdf_expand = HKDFExpand( - algorithm=self._kdf_params.hash, - length=length, - info=labeled_info, - ) - return hkdf_expand.derive(prk) - - def _key_schedule( - self, shared_secret: bytes, info: bytes - ) -> tuple[bytes, bytes]: - mode = _HPKE_MODE_BASE - - psk_id_hash = self._hpke_labeled_extract(b"", b"psk_id_hash", b"") - info_hash = self._hpke_labeled_extract(b"", b"info_hash", info) - key_schedule_context = bytes([mode]) + psk_id_hash + info_hash - - secret = self._hpke_labeled_extract(shared_secret, b"secret", b"") - - key = self._hpke_labeled_expand( - secret, b"key", key_schedule_context, self._aead_params.nk - ) - base_nonce = self._hpke_labeled_expand( - secret, - b"base_nonce", - key_schedule_context, - self._aead_params.nn, - ) - - return key, base_nonce - - def encrypt( - self, - plaintext: bytes, - public_key: x25519.X25519PublicKey, - info: bytes = b"", - aad: bytes = b"", - ) -> bytes: - shared_secret, enc = self._encap(public_key) - key, base_nonce = self._key_schedule(shared_secret, info) - aead_impl = AESGCM(key) - ct = aead_impl.encrypt(base_nonce, plaintext, aad) - return enc + ct - - def decrypt( - self, - ciphertext: bytes, - private_key: x25519.X25519PrivateKey, - info: bytes = b"", - aad: bytes = b"", - ) -> bytes: - nenc = self._kem_params.nenc - enc = ciphertext[:nenc] - ct = ciphertext[nenc:] - shared_secret = self._decap(enc, private_key) - key, base_nonce = self._key_schedule(shared_secret, info) - aead_impl = AESGCM(key) - return aead_impl.decrypt(base_nonce, ct, aad) +from cryptography.hazmat.bindings._rust import openssl as rust_openssl +AEAD = rust_openssl.hpke.AEAD +KDF = rust_openssl.hpke.KDF +KEM = rust_openssl.hpke.KEM +Suite = rust_openssl.hpke.Suite __all__ = [ "AEAD", diff --git a/src/rust/src/backend/hpke.rs b/src/rust/src/backend/hpke.rs new file mode 100644 index 000000000000..9e32f018bd6f --- /dev/null +++ b/src/rust/src/backend/hpke.rs @@ -0,0 +1,555 @@ +// This file is dual licensed under the terms of the Apache License, Version +// 2.0, and the BSD License. See the LICENSE file in the root of this repository +// for complete details. + +use crate::backend::hmac::Hmac; +use crate::backend::x25519; +use crate::buf::CffiBuf; +use crate::error::{CryptographyError, CryptographyResult}; +use crate::exceptions; +use crate::types; +use pyo3::types::{PyAnyMethods, PyBytesMethods}; + +const HPKE_VERSION: &[u8] = b"HPKE-v1"; +const HPKE_MODE_BASE: u8 = 0x00; + +// KEM parameters for X25519 (DHKEM(X25519, HKDF-SHA256)) +const KEM_ID: u16 = 0x0020; +const KEM_NSECRET: usize = 32; +const KEM_NENC: usize = 32; + +// KDF parameters for HKDF-SHA256 +const KDF_ID: u16 = 0x0001; + +// AEAD parameters for AES-128-GCM +const AEAD_ID: u16 = 0x0001; +const AEAD_NK: usize = 16; +const AEAD_NN: usize = 12; +const AEAD_NT: usize = 16; + +fn int_to_bytes(value: u16, length: usize) -> Vec { + let bytes = value.to_be_bytes(); + if length == 1 { + vec![bytes[1]] + } else { + bytes.to_vec() + } +} + +#[allow(clippy::upper_case_acronyms)] +#[pyo3::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.openssl.hpke")] +pub(crate) struct KEM { + _value: String, +} + +#[pyo3::pymethods] +impl KEM { + #[classattr] + #[pyo3(name = "X25519")] + fn x25519() -> KEM { + KEM { + _value: "X25519".to_string(), + } + } + + fn __eq__(&self, other: &KEM) -> bool { + self._value == other._value + } + + fn __hash__(&self) -> u64 { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + let mut hasher = DefaultHasher::new(); + self._value.hash(&mut hasher); + hasher.finish() + } + + #[getter] + fn value(&self) -> &str { + &self._value + } +} + +#[allow(clippy::upper_case_acronyms)] +#[pyo3::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.openssl.hpke")] +pub(crate) struct KDF { + _value: String, +} + +#[pyo3::pymethods] +impl KDF { + #[classattr] + #[pyo3(name = "HKDF_SHA256")] + fn hkdf_sha256() -> KDF { + KDF { + _value: "HKDF_SHA256".to_string(), + } + } + + fn __eq__(&self, other: &KDF) -> bool { + self._value == other._value + } + + fn __hash__(&self) -> u64 { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + let mut hasher = DefaultHasher::new(); + self._value.hash(&mut hasher); + hasher.finish() + } + + #[getter] + fn value(&self) -> &str { + &self._value + } +} + +#[allow(clippy::upper_case_acronyms)] +#[pyo3::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.openssl.hpke")] +pub(crate) struct AEAD { + _value: String, +} + +#[pyo3::pymethods] +impl AEAD { + #[classattr] + #[pyo3(name = "AES_128_GCM")] + fn aes_128_gcm() -> AEAD { + AEAD { + _value: "AES_128_GCM".to_string(), + } + } + + fn __eq__(&self, other: &AEAD) -> bool { + self._value == other._value + } + + fn __hash__(&self) -> u64 { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + let mut hasher = DefaultHasher::new(); + self._value.hash(&mut hasher); + hasher.finish() + } + + #[getter] + fn value(&self) -> &str { + &self._value + } +} + +#[pyo3::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.openssl.hpke")] +pub(crate) struct Suite { + kem_suite_id: Vec, + hpke_suite_id: Vec, +} + +impl Suite { + fn hkdf_extract( + &self, + py: pyo3::Python<'_>, + salt: &[u8], + ikm: &[u8], + ) -> CryptographyResult> { + let sha256 = types::SHA256.get(py)?.call0()?; + let digest_size = sha256 + .getattr(pyo3::intern!(py, "digest_size"))? + .extract::()?; + let default_salt = vec![0u8; digest_size]; + let salt_bytes = if salt.is_empty() { &default_salt } else { salt }; + let mut hmac = Hmac::new_bytes(py, salt_bytes, &sha256)?; + hmac.update_bytes(ikm)?; + let result = hmac.finalize_bytes()?; + Ok(result.to_vec()) + } + + fn hkdf_expand( + &self, + py: pyo3::Python<'_>, + prk: &[u8], + info: &[u8], + length: usize, + ) -> CryptographyResult> { + let sha256 = types::SHA256.get(py)?.call0()?; + let digest_size = sha256 + .getattr(pyo3::intern!(py, "digest_size"))? + .extract::()?; + + let h_prime = Hmac::new_bytes(py, prk, &sha256)?; + + let mut output = vec![0u8; length]; + let mut pos = 0usize; + let mut counter = 0u8; + + while pos < length { + counter += 1; + let mut h = h_prime.copy(py)?; + + let start = pos.saturating_sub(digest_size); + h.update_bytes(&output[start..pos])?; + h.update_bytes(info)?; + h.update_bytes(&[counter])?; + + let block = h.finalize(py)?; + let block_bytes = block.as_bytes(); + + let copy_len = (length - pos).min(digest_size); + output[pos..pos + copy_len].copy_from_slice(&block_bytes[..copy_len]); + pos += copy_len; + } + + Ok(output) + } + + fn kem_labeled_extract( + &self, + py: pyo3::Python<'_>, + salt: &[u8], + label: &[u8], + ikm: &[u8], + ) -> CryptographyResult> { + let mut labeled_ikm = Vec::new(); + labeled_ikm.extend_from_slice(HPKE_VERSION); + labeled_ikm.extend_from_slice(&self.kem_suite_id); + labeled_ikm.extend_from_slice(label); + labeled_ikm.extend_from_slice(ikm); + self.hkdf_extract(py, salt, &labeled_ikm) + } + + fn kem_labeled_expand( + &self, + py: pyo3::Python<'_>, + prk: &[u8], + label: &[u8], + info: &[u8], + length: usize, + ) -> CryptographyResult> { + let mut labeled_info = Vec::new(); + labeled_info.extend_from_slice(&int_to_bytes(length as u16, 2)); + labeled_info.extend_from_slice(HPKE_VERSION); + labeled_info.extend_from_slice(&self.kem_suite_id); + labeled_info.extend_from_slice(label); + labeled_info.extend_from_slice(info); + self.hkdf_expand(py, prk, &labeled_info, length) + } + + fn extract_and_expand( + &self, + py: pyo3::Python<'_>, + dh: &[u8], + kem_context: &[u8], + ) -> CryptographyResult> { + let eae_prk = self.kem_labeled_extract(py, b"", b"eae_prk", dh)?; + self.kem_labeled_expand(py, &eae_prk, b"shared_secret", kem_context, KEM_NSECRET) + } + + fn encap( + &self, + py: pyo3::Python<'_>, + pk_r: &x25519::X25519PublicKey, + ) -> CryptographyResult<(Vec, Vec)> { + // Generate ephemeral key pair using OpenSSL directly + let sk_e_pkey = openssl::pkey::PKey::generate_x25519()?; + let pk_e_raw = sk_e_pkey.raw_public_key()?; + + // Exchange using the ephemeral private key and recipient's public key + let pk_r_raw = pk_r.public_bytes_raw_internal(py)?; + let pk_r_pkey = + openssl::pkey::PKey::public_key_from_raw_bytes(&pk_r_raw, openssl::pkey::Id::X25519)?; + + let mut deriver = openssl::derive::Deriver::new(&sk_e_pkey)?; + deriver.set_peer(&pk_r_pkey)?; + let mut dh = vec![0u8; deriver.len()?]; + let n = deriver.derive(&mut dh)?; + assert_eq!(n, dh.len()); + + let mut kem_context = Vec::new(); + kem_context.extend_from_slice(&pk_e_raw); + kem_context.extend_from_slice(&pk_r_raw); + let shared_secret = self.extract_and_expand(py, &dh, &kem_context)?; + Ok((shared_secret, pk_e_raw)) + } + + fn decap( + &self, + py: pyo3::Python<'_>, + enc: &[u8], + sk_r: &x25519::X25519PrivateKey, + ) -> CryptographyResult> { + // Reconstruct pk_e from enc + let pk_e_pkey = + openssl::pkey::PKey::public_key_from_raw_bytes(enc, openssl::pkey::Id::X25519) + .map_err(|_| { + CryptographyError::from(pyo3::exceptions::PyValueError::new_err( + "Invalid encapsulated key", + )) + })?; + + // Get our private key for ECDH + let sk_r_raw = sk_r.private_bytes_raw_internal(py)?; + let sk_r_pkey = + openssl::pkey::PKey::private_key_from_raw_bytes(&sk_r_raw, openssl::pkey::Id::X25519)?; + + // Perform ECDH + let mut deriver = openssl::derive::Deriver::new(&sk_r_pkey)?; + deriver.set_peer(&pk_e_pkey)?; + let mut dh = vec![0u8; deriver.len()?]; + let n = deriver.derive(&mut dh)?; + assert_eq!(n, dh.len()); + + // Get our public key + let pk_rm = sk_r_pkey.raw_public_key()?; + + let mut kem_context = Vec::new(); + kem_context.extend_from_slice(enc); + kem_context.extend_from_slice(&pk_rm); + self.extract_and_expand(py, &dh, &kem_context) + } + + fn hpke_labeled_extract( + &self, + py: pyo3::Python<'_>, + salt: &[u8], + label: &[u8], + ikm: &[u8], + ) -> CryptographyResult> { + let mut labeled_ikm = Vec::new(); + labeled_ikm.extend_from_slice(HPKE_VERSION); + labeled_ikm.extend_from_slice(&self.hpke_suite_id); + labeled_ikm.extend_from_slice(label); + labeled_ikm.extend_from_slice(ikm); + self.hkdf_extract(py, salt, &labeled_ikm) + } + + fn hpke_labeled_expand( + &self, + py: pyo3::Python<'_>, + prk: &[u8], + label: &[u8], + info: &[u8], + length: usize, + ) -> CryptographyResult> { + let mut labeled_info = Vec::new(); + labeled_info.extend_from_slice(&int_to_bytes(length as u16, 2)); + labeled_info.extend_from_slice(HPKE_VERSION); + labeled_info.extend_from_slice(&self.hpke_suite_id); + labeled_info.extend_from_slice(label); + labeled_info.extend_from_slice(info); + self.hkdf_expand(py, prk, &labeled_info, length) + } + + fn key_schedule( + &self, + py: pyo3::Python<'_>, + shared_secret: &[u8], + info: &[u8], + ) -> CryptographyResult<(Vec, Vec)> { + let psk_id_hash = self.hpke_labeled_extract(py, b"", b"psk_id_hash", b"")?; + let info_hash = self.hpke_labeled_extract(py, b"", b"info_hash", info)?; + let mut key_schedule_context = vec![HPKE_MODE_BASE]; + key_schedule_context.extend_from_slice(&psk_id_hash); + key_schedule_context.extend_from_slice(&info_hash); + + let secret = self.hpke_labeled_extract(py, shared_secret, b"secret", b"")?; + + let key = self.hpke_labeled_expand(py, &secret, b"key", &key_schedule_context, AEAD_NK)?; + let base_nonce = + self.hpke_labeled_expand(py, &secret, b"base_nonce", &key_schedule_context, AEAD_NN)?; + + Ok((key, base_nonce)) + } + + fn aead_encrypt( + &self, + key: &[u8], + nonce: &[u8], + plaintext: &[u8], + aad: &[u8], + ) -> CryptographyResult> { + let cipher = match key.len() { + 16 => openssl::cipher::Cipher::aes_128_gcm(), + 24 => openssl::cipher::Cipher::aes_192_gcm(), + 32 => openssl::cipher::Cipher::aes_256_gcm(), + _ => { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("Invalid key length"), + )) + } + }; + + let mut ctx = openssl::cipher_ctx::CipherCtx::new()?; + ctx.encrypt_init(Some(cipher), Some(key), None)?; + ctx.set_iv_length(nonce.len())?; + ctx.encrypt_init(None, None, Some(nonce))?; + + // Process AAD + if !aad.is_empty() { + ctx.cipher_update(aad, None)?; + } + + // Encrypt plaintext + let mut ciphertext = vec![0u8; plaintext.len() + AEAD_NT]; + let n = ctx.cipher_update(plaintext, Some(&mut ciphertext[..plaintext.len()]))?; + assert_eq!(n, plaintext.len()); + + let mut final_block = [0u8; 0]; + let n = ctx.cipher_final(&mut final_block)?; + assert_eq!(n, 0); + + // Get tag + ctx.tag(&mut ciphertext[plaintext.len()..]) + .map_err(CryptographyError::from)?; + + Ok(ciphertext) + } + + fn aead_decrypt( + &self, + key: &[u8], + nonce: &[u8], + ciphertext: &[u8], + aad: &[u8], + ) -> CryptographyResult> { + if ciphertext.len() < AEAD_NT { + return Err(CryptographyError::from(exceptions::InvalidTag::new_err(()))); + } + + let cipher = match key.len() { + 16 => openssl::cipher::Cipher::aes_128_gcm(), + 24 => openssl::cipher::Cipher::aes_192_gcm(), + 32 => openssl::cipher::Cipher::aes_256_gcm(), + _ => { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("Invalid key length"), + )) + } + }; + + let ct_len = ciphertext.len() - AEAD_NT; + let (ct_data, tag) = ciphertext.split_at(ct_len); + + let mut ctx = openssl::cipher_ctx::CipherCtx::new()?; + ctx.decrypt_init(Some(cipher), Some(key), None)?; + ctx.set_iv_length(nonce.len())?; + ctx.decrypt_init(None, None, Some(nonce))?; + ctx.set_tag(tag)?; + + // Process AAD + if !aad.is_empty() { + ctx.cipher_update(aad, None)?; + } + + // Decrypt ciphertext + let mut plaintext = vec![0u8; ct_len]; + let n = ctx + .cipher_update(ct_data, Some(&mut plaintext)) + .map_err(|_| exceptions::InvalidTag::new_err(()))?; + assert_eq!(n, ct_len); + + ctx.cipher_final(&mut []) + .map_err(|_| exceptions::InvalidTag::new_err(()))?; + + Ok(plaintext) + } +} + +#[pyo3::pymethods] +impl Suite { + #[new] + fn new( + _py: pyo3::Python<'_>, + kem: &pyo3::Bound<'_, pyo3::PyAny>, + kdf: &pyo3::Bound<'_, pyo3::PyAny>, + aead: &pyo3::Bound<'_, pyo3::PyAny>, + ) -> CryptographyResult { + // Validate types + if !kem.is_instance_of::() { + return Err(CryptographyError::from( + pyo3::exceptions::PyTypeError::new_err("kem must be an instance of KEM"), + )); + } + if !kdf.is_instance_of::() { + return Err(CryptographyError::from( + pyo3::exceptions::PyTypeError::new_err("kdf must be an instance of KDF"), + )); + } + if !aead.is_instance_of::() { + return Err(CryptographyError::from( + pyo3::exceptions::PyTypeError::new_err("aead must be an instance of AEAD"), + )); + } + + // Build suite IDs + let mut kem_suite_id = Vec::new(); + kem_suite_id.extend_from_slice(b"KEM"); + kem_suite_id.extend_from_slice(&int_to_bytes(KEM_ID, 2)); + + let mut hpke_suite_id = Vec::new(); + hpke_suite_id.extend_from_slice(b"HPKE"); + hpke_suite_id.extend_from_slice(&int_to_bytes(KEM_ID, 2)); + hpke_suite_id.extend_from_slice(&int_to_bytes(KDF_ID, 2)); + hpke_suite_id.extend_from_slice(&int_to_bytes(AEAD_ID, 2)); + + Ok(Suite { + kem_suite_id, + hpke_suite_id, + }) + } + + #[pyo3(signature = (plaintext, public_key, info=None, aad=None))] + fn encrypt<'p>( + &self, + py: pyo3::Python<'p>, + plaintext: CffiBuf<'_>, + public_key: &x25519::X25519PublicKey, + info: Option>, + aad: Option>, + ) -> CryptographyResult> { + let info_bytes = info.map(|b| b.as_bytes().to_vec()).unwrap_or_default(); + let aad_bytes = aad.map(|b| b.as_bytes().to_vec()).unwrap_or_default(); + + let (shared_secret, enc) = self.encap(py, public_key)?; + let (key, base_nonce) = self.key_schedule(py, &shared_secret, &info_bytes)?; + let ct = self.aead_encrypt(&key, &base_nonce, plaintext.as_bytes(), &aad_bytes)?; + + // Combine enc + ct + let mut result = Vec::with_capacity(enc.len() + ct.len()); + result.extend_from_slice(&enc); + result.extend_from_slice(&ct); + + Ok(pyo3::types::PyBytes::new(py, &result)) + } + + #[pyo3(signature = (ciphertext, private_key, info=None, aad=None))] + fn decrypt<'p>( + &self, + py: pyo3::Python<'p>, + ciphertext: CffiBuf<'_>, + private_key: &x25519::X25519PrivateKey, + info: Option>, + aad: Option>, + ) -> CryptographyResult> { + let ct_bytes = ciphertext.as_bytes(); + if ct_bytes.len() < KEM_NENC { + return Err(CryptographyError::from(exceptions::InvalidTag::new_err(()))); + } + + let info_bytes = info.map(|b| b.as_bytes().to_vec()).unwrap_or_default(); + let aad_bytes = aad.map(|b| b.as_bytes().to_vec()).unwrap_or_default(); + + let enc = &ct_bytes[..KEM_NENC]; + let ct = &ct_bytes[KEM_NENC..]; + + let shared_secret = self.decap(py, enc, private_key)?; + let (key, base_nonce) = self.key_schedule(py, &shared_secret, &info_bytes)?; + let plaintext = self.aead_decrypt(&key, &base_nonce, ct, &aad_bytes)?; + + Ok(pyo3::types::PyBytes::new(py, &plaintext)) + } +} + +#[pyo3::pymodule(gil_used = false)] +pub(crate) mod hpke { + #[pymodule_export] + use super::{Suite, AEAD, KDF, KEM}; +} diff --git a/src/rust/src/backend/mod.rs b/src/rust/src/backend/mod.rs index aceea4c166c9..a9133cafb8c8 100644 --- a/src/rust/src/backend/mod.rs +++ b/src/rust/src/backend/mod.rs @@ -18,6 +18,7 @@ pub(crate) mod ed25519; pub(crate) mod ed448; pub(crate) mod hashes; pub(crate) mod hmac; +pub(crate) mod hpke; pub(crate) mod kdf; pub(crate) mod keys; pub(crate) mod poly1305; diff --git a/src/rust/src/backend/x25519.rs b/src/rust/src/backend/x25519.rs index 9ee092725aec..2230872b9f4f 100644 --- a/src/rust/src/backend/x25519.rs +++ b/src/rust/src/backend/x25519.rs @@ -60,6 +60,16 @@ fn from_public_bytes(data: &[u8]) -> pyo3::PyResult { Ok(X25519PublicKey { pkey }) } +impl X25519PrivateKey { + // Internal method for use by other Rust modules (e.g., HPKE) + pub(crate) fn private_bytes_raw_internal( + &self, + _py: pyo3::Python<'_>, + ) -> CryptographyResult> { + Ok(self.pkey.raw_private_key()?) + } +} + #[pyo3::pymethods] impl X25519PrivateKey { fn exchange<'p>( @@ -128,6 +138,16 @@ impl X25519PrivateKey { } } +impl X25519PublicKey { + // Internal method for use by other Rust modules (e.g., HPKE) + pub(crate) fn public_bytes_raw_internal( + &self, + _py: pyo3::Python<'_>, + ) -> CryptographyResult> { + Ok(self.pkey.raw_public_key()?) + } +} + #[pyo3::pymethods] impl X25519PublicKey { fn public_bytes_raw<'p>( diff --git a/src/rust/src/lib.rs b/src/rust/src/lib.rs index 32067299f467..dff16a94dc01 100644 --- a/src/rust/src/lib.rs +++ b/src/rust/src/lib.rs @@ -231,6 +231,8 @@ mod _rust { #[pymodule_export] use crate::backend::hmac::hmac; #[pymodule_export] + use crate::backend::hpke::hpke; + #[pymodule_export] use crate::backend::kdf::kdf; #[pymodule_export] use crate::backend::keys::keys; From 049b0f53ca3610e01b8f17bb448ec55f21a0d3ea Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 28 Jan 2026 23:48:15 +0000 Subject: [PATCH 2/7] Refactor HPKE Rust implementation - Use proper Rust enums with #[pyclass] for KEM, KDF, and AEAD types - Accept Python ABCs for keys using pyo3::PyAny and Python APIs - Remove unnecessary internal helper methods from x25519.rs - Simplify type stub file by removing unnecessary value properties - Update tests to work with pyo3's native type checking error messages https://claude.ai/code/session_01W43m9LudrvqkHKr4BrMa7c --- .../hazmat/bindings/_rust/openssl/hpke.pyi | 6 - src/rust/src/backend/hpke.rs | 302 +++++++----------- src/rust/src/backend/x25519.rs | 20 -- tests/hazmat/primitives/test_hpke.py | 8 +- 4 files changed, 110 insertions(+), 226 deletions(-) diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/hpke.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/hpke.pyi index 609b6587da55..885436a9424c 100644 --- a/src/cryptography/hazmat/bindings/_rust/openssl/hpke.pyi +++ b/src/cryptography/hazmat/bindings/_rust/openssl/hpke.pyi @@ -7,22 +7,16 @@ from cryptography.utils import Buffer class KEM: X25519: KEM - @property - def value(self) -> str: ... def __eq__(self, other: object) -> bool: ... def __hash__(self) -> int: ... class KDF: HKDF_SHA256: KDF - @property - def value(self) -> str: ... def __eq__(self, other: object) -> bool: ... def __hash__(self) -> int: ... class AEAD: AES_128_GCM: AEAD - @property - def value(self) -> str: ... def __eq__(self, other: object) -> bool: ... def __hash__(self) -> int: ... diff --git a/src/rust/src/backend/hpke.rs b/src/rust/src/backend/hpke.rs index 9e32f018bd6f..a957a65cf4b1 100644 --- a/src/rust/src/backend/hpke.rs +++ b/src/rust/src/backend/hpke.rs @@ -3,7 +3,6 @@ // for complete details. use crate::backend::hmac::Hmac; -use crate::backend::x25519; use crate::buf::CffiBuf; use crate::error::{CryptographyError, CryptographyResult}; use crate::exceptions; @@ -27,121 +26,57 @@ const AEAD_NK: usize = 16; const AEAD_NN: usize = 12; const AEAD_NT: usize = 16; -fn int_to_bytes(value: u16, length: usize) -> Vec { - let bytes = value.to_be_bytes(); - if length == 1 { - vec![bytes[1]] - } else { - bytes.to_vec() - } -} - #[allow(clippy::upper_case_acronyms)] -#[pyo3::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.openssl.hpke")] -pub(crate) struct KEM { - _value: String, +#[pyo3::pyclass( + frozen, + eq, + hash, + module = "cryptography.hazmat.bindings._rust.openssl.hpke" +)] +#[derive(Clone, PartialEq, Eq, Hash)] +pub(crate) enum KEM { + X25519, } #[pyo3::pymethods] -impl KEM { - #[classattr] - #[pyo3(name = "X25519")] - fn x25519() -> KEM { - KEM { - _value: "X25519".to_string(), - } - } - - fn __eq__(&self, other: &KEM) -> bool { - self._value == other._value - } - - fn __hash__(&self) -> u64 { - use std::collections::hash_map::DefaultHasher; - use std::hash::{Hash, Hasher}; - let mut hasher = DefaultHasher::new(); - self._value.hash(&mut hasher); - hasher.finish() - } - - #[getter] - fn value(&self) -> &str { - &self._value - } -} +impl KEM {} #[allow(clippy::upper_case_acronyms)] -#[pyo3::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.openssl.hpke")] -pub(crate) struct KDF { - _value: String, +#[allow(non_camel_case_types)] +#[pyo3::pyclass( + frozen, + eq, + hash, + module = "cryptography.hazmat.bindings._rust.openssl.hpke" +)] +#[derive(Clone, PartialEq, Eq, Hash)] +pub(crate) enum KDF { + HKDF_SHA256, } #[pyo3::pymethods] -impl KDF { - #[classattr] - #[pyo3(name = "HKDF_SHA256")] - fn hkdf_sha256() -> KDF { - KDF { - _value: "HKDF_SHA256".to_string(), - } - } - - fn __eq__(&self, other: &KDF) -> bool { - self._value == other._value - } - - fn __hash__(&self) -> u64 { - use std::collections::hash_map::DefaultHasher; - use std::hash::{Hash, Hasher}; - let mut hasher = DefaultHasher::new(); - self._value.hash(&mut hasher); - hasher.finish() - } - - #[getter] - fn value(&self) -> &str { - &self._value - } -} +impl KDF {} #[allow(clippy::upper_case_acronyms)] -#[pyo3::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.openssl.hpke")] -pub(crate) struct AEAD { - _value: String, +#[allow(non_camel_case_types)] +#[pyo3::pyclass( + frozen, + eq, + hash, + module = "cryptography.hazmat.bindings._rust.openssl.hpke" +)] +#[derive(Clone, PartialEq, Eq, Hash)] +pub(crate) enum AEAD { + AES_128_GCM, } #[pyo3::pymethods] -impl AEAD { - #[classattr] - #[pyo3(name = "AES_128_GCM")] - fn aes_128_gcm() -> AEAD { - AEAD { - _value: "AES_128_GCM".to_string(), - } - } - - fn __eq__(&self, other: &AEAD) -> bool { - self._value == other._value - } - - fn __hash__(&self) -> u64 { - use std::collections::hash_map::DefaultHasher; - use std::hash::{Hash, Hasher}; - let mut hasher = DefaultHasher::new(); - self._value.hash(&mut hasher); - hasher.finish() - } - - #[getter] - fn value(&self) -> &str { - &self._value - } -} +impl AEAD {} #[pyo3::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.openssl.hpke")] pub(crate) struct Suite { - kem_suite_id: Vec, - hpke_suite_id: Vec, + kem_suite_id: [u8; 5], + hpke_suite_id: [u8; 10], } impl Suite { @@ -150,7 +85,7 @@ impl Suite { py: pyo3::Python<'_>, salt: &[u8], ikm: &[u8], - ) -> CryptographyResult> { + ) -> CryptographyResult { let sha256 = types::SHA256.get(py)?.call0()?; let digest_size = sha256 .getattr(pyo3::intern!(py, "digest_size"))? @@ -159,8 +94,7 @@ impl Suite { let salt_bytes = if salt.is_empty() { &default_salt } else { salt }; let mut hmac = Hmac::new_bytes(py, salt_bytes, &sha256)?; hmac.update_bytes(ikm)?; - let result = hmac.finalize_bytes()?; - Ok(result.to_vec()) + hmac.finalize_bytes() } fn hkdf_expand( @@ -207,8 +141,8 @@ impl Suite { salt: &[u8], label: &[u8], ikm: &[u8], - ) -> CryptographyResult> { - let mut labeled_ikm = Vec::new(); + ) -> CryptographyResult { + let mut labeled_ikm = Vec::with_capacity(HPKE_VERSION.len() + 5 + label.len() + ikm.len()); labeled_ikm.extend_from_slice(HPKE_VERSION); labeled_ikm.extend_from_slice(&self.kem_suite_id); labeled_ikm.extend_from_slice(label); @@ -224,8 +158,9 @@ impl Suite { info: &[u8], length: usize, ) -> CryptographyResult> { - let mut labeled_info = Vec::new(); - labeled_info.extend_from_slice(&int_to_bytes(length as u16, 2)); + let mut labeled_info = + Vec::with_capacity(2 + HPKE_VERSION.len() + 5 + label.len() + info.len()); + labeled_info.extend_from_slice(&(length as u16).to_be_bytes()); labeled_info.extend_from_slice(HPKE_VERSION); labeled_info.extend_from_slice(&self.kem_suite_id); labeled_info.extend_from_slice(label); @@ -246,26 +181,30 @@ impl Suite { fn encap( &self, py: pyo3::Python<'_>, - pk_r: &x25519::X25519PublicKey, + pk_r: &pyo3::Bound<'_, pyo3::PyAny>, ) -> CryptographyResult<(Vec, Vec)> { // Generate ephemeral key pair using OpenSSL directly let sk_e_pkey = openssl::pkey::PKey::generate_x25519()?; let pk_e_raw = sk_e_pkey.raw_public_key()?; - // Exchange using the ephemeral private key and recipient's public key - let pk_r_raw = pk_r.public_bytes_raw_internal(py)?; + // Get recipient's public key raw bytes via Python API + let pk_r_bytes = pk_r.call_method0(pyo3::intern!(py, "public_bytes_raw"))?; + let pk_r_raw = pk_r_bytes.extract::<&[u8]>()?; + + // Create recipient public key from raw bytes let pk_r_pkey = - openssl::pkey::PKey::public_key_from_raw_bytes(&pk_r_raw, openssl::pkey::Id::X25519)?; + openssl::pkey::PKey::public_key_from_raw_bytes(pk_r_raw, openssl::pkey::Id::X25519)?; + // Perform ECDH let mut deriver = openssl::derive::Deriver::new(&sk_e_pkey)?; deriver.set_peer(&pk_r_pkey)?; let mut dh = vec![0u8; deriver.len()?]; let n = deriver.derive(&mut dh)?; assert_eq!(n, dh.len()); - let mut kem_context = Vec::new(); - kem_context.extend_from_slice(&pk_e_raw); - kem_context.extend_from_slice(&pk_r_raw); + let mut kem_context = [0u8; 64]; + kem_context[..32].copy_from_slice(&pk_e_raw); + kem_context[32..].copy_from_slice(pk_r_raw); let shared_secret = self.extract_and_expand(py, &dh, &kem_context)?; Ok((shared_secret, pk_e_raw)) } @@ -274,7 +213,7 @@ impl Suite { &self, py: pyo3::Python<'_>, enc: &[u8], - sk_r: &x25519::X25519PrivateKey, + sk_r: &pyo3::Bound<'_, pyo3::PyAny>, ) -> CryptographyResult> { // Reconstruct pk_e from enc let pk_e_pkey = @@ -285,10 +224,13 @@ impl Suite { )) })?; - // Get our private key for ECDH - let sk_r_raw = sk_r.private_bytes_raw_internal(py)?; + // Get our private key raw bytes via Python API + let sk_r_bytes = sk_r.call_method0(pyo3::intern!(py, "private_bytes_raw"))?; + let sk_r_raw = sk_r_bytes.extract::<&[u8]>()?; + + // Reconstruct private key from raw bytes let sk_r_pkey = - openssl::pkey::PKey::private_key_from_raw_bytes(&sk_r_raw, openssl::pkey::Id::X25519)?; + openssl::pkey::PKey::private_key_from_raw_bytes(sk_r_raw, openssl::pkey::Id::X25519)?; // Perform ECDH let mut deriver = openssl::derive::Deriver::new(&sk_r_pkey)?; @@ -300,9 +242,9 @@ impl Suite { // Get our public key let pk_rm = sk_r_pkey.raw_public_key()?; - let mut kem_context = Vec::new(); - kem_context.extend_from_slice(enc); - kem_context.extend_from_slice(&pk_rm); + let mut kem_context = [0u8; 64]; + kem_context[..32].copy_from_slice(enc); + kem_context[32..].copy_from_slice(&pk_rm); self.extract_and_expand(py, &dh, &kem_context) } @@ -312,8 +254,8 @@ impl Suite { salt: &[u8], label: &[u8], ikm: &[u8], - ) -> CryptographyResult> { - let mut labeled_ikm = Vec::new(); + ) -> CryptographyResult { + let mut labeled_ikm = Vec::with_capacity(HPKE_VERSION.len() + 10 + label.len() + ikm.len()); labeled_ikm.extend_from_slice(HPKE_VERSION); labeled_ikm.extend_from_slice(&self.hpke_suite_id); labeled_ikm.extend_from_slice(label); @@ -329,8 +271,9 @@ impl Suite { info: &[u8], length: usize, ) -> CryptographyResult> { - let mut labeled_info = Vec::new(); - labeled_info.extend_from_slice(&int_to_bytes(length as u16, 2)); + let mut labeled_info = + Vec::with_capacity(2 + HPKE_VERSION.len() + 10 + label.len() + info.len()); + labeled_info.extend_from_slice(&(length as u16).to_be_bytes()); labeled_info.extend_from_slice(HPKE_VERSION); labeled_info.extend_from_slice(&self.hpke_suite_id); labeled_info.extend_from_slice(label); @@ -343,7 +286,7 @@ impl Suite { py: pyo3::Python<'_>, shared_secret: &[u8], info: &[u8], - ) -> CryptographyResult<(Vec, Vec)> { + ) -> CryptographyResult<([u8; AEAD_NK], [u8; AEAD_NN])> { let psk_id_hash = self.hpke_labeled_extract(py, b"", b"psk_id_hash", b"")?; let info_hash = self.hpke_labeled_extract(py, b"", b"info_hash", info)?; let mut key_schedule_context = vec![HPKE_MODE_BASE]; @@ -352,10 +295,16 @@ impl Suite { let secret = self.hpke_labeled_extract(py, shared_secret, b"secret", b"")?; - let key = self.hpke_labeled_expand(py, &secret, b"key", &key_schedule_context, AEAD_NK)?; - let base_nonce = + let key_vec = + self.hpke_labeled_expand(py, &secret, b"key", &key_schedule_context, AEAD_NK)?; + let nonce_vec = self.hpke_labeled_expand(py, &secret, b"base_nonce", &key_schedule_context, AEAD_NN)?; + let mut key = [0u8; AEAD_NK]; + let mut base_nonce = [0u8; AEAD_NN]; + key.copy_from_slice(&key_vec); + base_nonce.copy_from_slice(&nonce_vec); + Ok((key, base_nonce)) } @@ -366,16 +315,7 @@ impl Suite { plaintext: &[u8], aad: &[u8], ) -> CryptographyResult> { - let cipher = match key.len() { - 16 => openssl::cipher::Cipher::aes_128_gcm(), - 24 => openssl::cipher::Cipher::aes_192_gcm(), - 32 => openssl::cipher::Cipher::aes_256_gcm(), - _ => { - return Err(CryptographyError::from( - pyo3::exceptions::PyValueError::new_err("Invalid key length"), - )) - } - }; + let cipher = openssl::cipher::Cipher::aes_128_gcm(); let mut ctx = openssl::cipher_ctx::CipherCtx::new()?; ctx.encrypt_init(Some(cipher), Some(key), None)?; @@ -414,16 +354,7 @@ impl Suite { return Err(CryptographyError::from(exceptions::InvalidTag::new_err(()))); } - let cipher = match key.len() { - 16 => openssl::cipher::Cipher::aes_128_gcm(), - 24 => openssl::cipher::Cipher::aes_192_gcm(), - 32 => openssl::cipher::Cipher::aes_256_gcm(), - _ => { - return Err(CryptographyError::from( - pyo3::exceptions::PyValueError::new_err("Invalid key length"), - )) - } - }; + let cipher = openssl::cipher::Cipher::aes_128_gcm(); let ct_len = ciphertext.len() - AEAD_NT; let (ct_data, tag) = ciphertext.split_at(ct_len); @@ -456,39 +387,17 @@ impl Suite { #[pyo3::pymethods] impl Suite { #[new] - fn new( - _py: pyo3::Python<'_>, - kem: &pyo3::Bound<'_, pyo3::PyAny>, - kdf: &pyo3::Bound<'_, pyo3::PyAny>, - aead: &pyo3::Bound<'_, pyo3::PyAny>, - ) -> CryptographyResult { - // Validate types - if !kem.is_instance_of::() { - return Err(CryptographyError::from( - pyo3::exceptions::PyTypeError::new_err("kem must be an instance of KEM"), - )); - } - if !kdf.is_instance_of::() { - return Err(CryptographyError::from( - pyo3::exceptions::PyTypeError::new_err("kdf must be an instance of KDF"), - )); - } - if !aead.is_instance_of::() { - return Err(CryptographyError::from( - pyo3::exceptions::PyTypeError::new_err("aead must be an instance of AEAD"), - )); - } - + fn new(_kem: KEM, _kdf: KDF, _aead: AEAD) -> CryptographyResult { // Build suite IDs - let mut kem_suite_id = Vec::new(); - kem_suite_id.extend_from_slice(b"KEM"); - kem_suite_id.extend_from_slice(&int_to_bytes(KEM_ID, 2)); + let mut kem_suite_id = [0u8; 5]; + kem_suite_id[..3].copy_from_slice(b"KEM"); + kem_suite_id[3..].copy_from_slice(&KEM_ID.to_be_bytes()); - let mut hpke_suite_id = Vec::new(); - hpke_suite_id.extend_from_slice(b"HPKE"); - hpke_suite_id.extend_from_slice(&int_to_bytes(KEM_ID, 2)); - hpke_suite_id.extend_from_slice(&int_to_bytes(KDF_ID, 2)); - hpke_suite_id.extend_from_slice(&int_to_bytes(AEAD_ID, 2)); + let mut hpke_suite_id = [0u8; 10]; + hpke_suite_id[..4].copy_from_slice(b"HPKE"); + hpke_suite_id[4..6].copy_from_slice(&KEM_ID.to_be_bytes()); + hpke_suite_id[6..8].copy_from_slice(&KDF_ID.to_be_bytes()); + hpke_suite_id[8..10].copy_from_slice(&AEAD_ID.to_be_bytes()); Ok(Suite { kem_suite_id, @@ -501,23 +410,27 @@ impl Suite { &self, py: pyo3::Python<'p>, plaintext: CffiBuf<'_>, - public_key: &x25519::X25519PublicKey, + public_key: &pyo3::Bound<'_, pyo3::PyAny>, info: Option>, aad: Option>, ) -> CryptographyResult> { - let info_bytes = info.map(|b| b.as_bytes().to_vec()).unwrap_or_default(); - let aad_bytes = aad.map(|b| b.as_bytes().to_vec()).unwrap_or_default(); + let info_bytes: &[u8] = info.as_ref().map(|b| b.as_bytes()).unwrap_or(b""); + let aad_bytes: &[u8] = aad.as_ref().map(|b| b.as_bytes()).unwrap_or(b""); let (shared_secret, enc) = self.encap(py, public_key)?; - let (key, base_nonce) = self.key_schedule(py, &shared_secret, &info_bytes)?; - let ct = self.aead_encrypt(&key, &base_nonce, plaintext.as_bytes(), &aad_bytes)?; + let (key, base_nonce) = self.key_schedule(py, &shared_secret, info_bytes)?; + let ct = self.aead_encrypt(&key, &base_nonce, plaintext.as_bytes(), aad_bytes)?; // Combine enc + ct - let mut result = Vec::with_capacity(enc.len() + ct.len()); - result.extend_from_slice(&enc); - result.extend_from_slice(&ct); - - Ok(pyo3::types::PyBytes::new(py, &result)) + Ok(pyo3::types::PyBytes::new_with( + py, + enc.len() + ct.len(), + |buf| { + buf[..enc.len()].copy_from_slice(&enc); + buf[enc.len()..].copy_from_slice(&ct); + Ok(()) + }, + )?) } #[pyo3(signature = (ciphertext, private_key, info=None, aad=None))] @@ -525,7 +438,7 @@ impl Suite { &self, py: pyo3::Python<'p>, ciphertext: CffiBuf<'_>, - private_key: &x25519::X25519PrivateKey, + private_key: &pyo3::Bound<'_, pyo3::PyAny>, info: Option>, aad: Option>, ) -> CryptographyResult> { @@ -534,15 +447,14 @@ impl Suite { return Err(CryptographyError::from(exceptions::InvalidTag::new_err(()))); } - let info_bytes = info.map(|b| b.as_bytes().to_vec()).unwrap_or_default(); - let aad_bytes = aad.map(|b| b.as_bytes().to_vec()).unwrap_or_default(); + let info_bytes: &[u8] = info.as_ref().map(|b| b.as_bytes()).unwrap_or(b""); + let aad_bytes: &[u8] = aad.as_ref().map(|b| b.as_bytes()).unwrap_or(b""); - let enc = &ct_bytes[..KEM_NENC]; - let ct = &ct_bytes[KEM_NENC..]; + let (enc, ct) = ct_bytes.split_at(KEM_NENC); let shared_secret = self.decap(py, enc, private_key)?; - let (key, base_nonce) = self.key_schedule(py, &shared_secret, &info_bytes)?; - let plaintext = self.aead_decrypt(&key, &base_nonce, ct, &aad_bytes)?; + let (key, base_nonce) = self.key_schedule(py, &shared_secret, info_bytes)?; + let plaintext = self.aead_decrypt(&key, &base_nonce, ct, aad_bytes)?; Ok(pyo3::types::PyBytes::new(py, &plaintext)) } diff --git a/src/rust/src/backend/x25519.rs b/src/rust/src/backend/x25519.rs index 2230872b9f4f..9ee092725aec 100644 --- a/src/rust/src/backend/x25519.rs +++ b/src/rust/src/backend/x25519.rs @@ -60,16 +60,6 @@ fn from_public_bytes(data: &[u8]) -> pyo3::PyResult { Ok(X25519PublicKey { pkey }) } -impl X25519PrivateKey { - // Internal method for use by other Rust modules (e.g., HPKE) - pub(crate) fn private_bytes_raw_internal( - &self, - _py: pyo3::Python<'_>, - ) -> CryptographyResult> { - Ok(self.pkey.raw_private_key()?) - } -} - #[pyo3::pymethods] impl X25519PrivateKey { fn exchange<'p>( @@ -138,16 +128,6 @@ impl X25519PrivateKey { } } -impl X25519PublicKey { - // Internal method for use by other Rust modules (e.g., HPKE) - pub(crate) fn public_bytes_raw_internal( - &self, - _py: pyo3::Python<'_>, - ) -> CryptographyResult> { - Ok(self.pkey.raw_public_key()?) - } -} - #[pyo3::pymethods] impl X25519PublicKey { fn public_bytes_raw<'p>( diff --git a/tests/hazmat/primitives/test_hpke.py b/tests/hazmat/primitives/test_hpke.py index f257aeb486da..794121c4d9c7 100644 --- a/tests/hazmat/primitives/test_hpke.py +++ b/tests/hazmat/primitives/test_hpke.py @@ -31,17 +31,15 @@ ) class TestHPKE: def test_invalid_kem_type(self): - with pytest.raises(TypeError, match="kem must be an instance of KEM"): + with pytest.raises(TypeError): Suite("not a kem", KDF.HKDF_SHA256, AEAD.AES_128_GCM) # type: ignore[arg-type] def test_invalid_kdf_type(self): - with pytest.raises(TypeError, match="kdf must be an instance of KDF"): + with pytest.raises(TypeError): Suite(KEM.X25519, "not a kdf", AEAD.AES_128_GCM) # type: ignore[arg-type] def test_invalid_aead_type(self): - with pytest.raises( - TypeError, match="aead must be an instance of AEAD" - ): + with pytest.raises(TypeError): Suite(KEM.X25519, KDF.HKDF_SHA256, "not an aead") # type: ignore[arg-type] @pytest.mark.parametrize("kem,kdf,aead", SUPPORTED_SUITES) From 556ee48ed8bd1814f52576fedf5770f9e5f7778d Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 29 Jan 2026 00:02:00 +0000 Subject: [PATCH 3/7] Refactor HPKE to reuse existing kdf and aead code - Organize algorithm constants into modules for easy extension - Remove empty enum impl blocks and unnecessary __eq__/__hash__ from type stubs - Reuse hkdf_extract from kdf.rs (made pub(crate)) - Reuse EvpCipherAead from aead.rs (made pub(crate)) - Use Python API for x25519 key generation and exchange operations - Simplify code by removing duplicated HKDF and AEAD implementations https://claude.ai/code/session_01W43m9LudrvqkHKr4BrMa7c --- .../hazmat/bindings/_rust/openssl/hpke.pyi | 6 - src/rust/src/backend/aead.rs | 10 +- src/rust/src/backend/hpke.rs | 340 ++++++++---------- src/rust/src/backend/kdf.rs | 2 +- 4 files changed, 163 insertions(+), 195 deletions(-) diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/hpke.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/hpke.pyi index 885436a9424c..25904206b3c1 100644 --- a/src/cryptography/hazmat/bindings/_rust/openssl/hpke.pyi +++ b/src/cryptography/hazmat/bindings/_rust/openssl/hpke.pyi @@ -7,18 +7,12 @@ from cryptography.utils import Buffer class KEM: X25519: KEM - def __eq__(self, other: object) -> bool: ... - def __hash__(self) -> int: ... class KDF: HKDF_SHA256: KDF - def __eq__(self, other: object) -> bool: ... - def __hash__(self) -> int: ... class AEAD: AES_128_GCM: AEAD - def __eq__(self, other: object) -> bool: ... - def __hash__(self) -> int: ... class Suite: def __init__(self, kem: KEM, kdf: KDF, aead: AEAD) -> None: ... diff --git a/src/rust/src/backend/aead.rs b/src/rust/src/backend/aead.rs index f94da711b8fe..f1afeb2d5c66 100644 --- a/src/rust/src/backend/aead.rs +++ b/src/rust/src/backend/aead.rs @@ -21,12 +21,12 @@ fn check_length(data: &[u8]) -> CryptographyResult<()> { Ok(()) } -enum Aad<'a> { +pub(crate) enum Aad<'a> { Single(CffiBuf<'a>), List(pyo3::Bound<'a, pyo3::types::PyList>), } -struct EvpCipherAead { +pub(crate) struct EvpCipherAead { base_encryption_ctx: openssl::cipher_ctx::CipherCtx, base_decryption_ctx: openssl::cipher_ctx::CipherCtx, tag_len: usize, @@ -34,7 +34,7 @@ struct EvpCipherAead { } impl EvpCipherAead { - fn new( + pub(crate) fn new( cipher: &openssl::cipher::CipherRef, key: &[u8], tag_len: usize, @@ -127,7 +127,7 @@ impl EvpCipherAead { Ok(()) } - fn encrypt_into( + pub(crate) fn encrypt_into( &self, // We have this arg so we have consistent arguments with encrypt_into in // LazyEvpCipherAead. We can remove it when we remove LazyEvpCipherAead. @@ -192,7 +192,7 @@ impl EvpCipherAead { Ok(()) } - fn decrypt_into( + pub(crate) fn decrypt_into( &self, // We have this arg so we have consistent arguments with decrypt_into in // LazyEvpCipherAead. We can remove it when we remove LazyEvpCipherAead. diff --git a/src/rust/src/backend/hpke.rs b/src/rust/src/backend/hpke.rs index a957a65cf4b1..f74c10019c5a 100644 --- a/src/rust/src/backend/hpke.rs +++ b/src/rust/src/backend/hpke.rs @@ -2,29 +2,34 @@ // 2.0, and the BSD License. See the LICENSE file in the root of this repository // for complete details. +use crate::backend::aead::EvpCipherAead; use crate::backend::hmac::Hmac; +use crate::backend::kdf::hkdf_extract; use crate::buf::CffiBuf; use crate::error::{CryptographyError, CryptographyResult}; use crate::exceptions; -use crate::types; use pyo3::types::{PyAnyMethods, PyBytesMethods}; const HPKE_VERSION: &[u8] = b"HPKE-v1"; const HPKE_MODE_BASE: u8 = 0x00; -// KEM parameters for X25519 (DHKEM(X25519, HKDF-SHA256)) -const KEM_ID: u16 = 0x0020; -const KEM_NSECRET: usize = 32; -const KEM_NENC: usize = 32; +// Algorithm parameters organized by type for easy extension +mod kem_params { + pub const X25519_ID: u16 = 0x0020; + pub const X25519_NSECRET: usize = 32; + pub const X25519_NENC: usize = 32; +} -// KDF parameters for HKDF-SHA256 -const KDF_ID: u16 = 0x0001; +mod kdf_params { + pub const HKDF_SHA256_ID: u16 = 0x0001; +} -// AEAD parameters for AES-128-GCM -const AEAD_ID: u16 = 0x0001; -const AEAD_NK: usize = 16; -const AEAD_NN: usize = 12; -const AEAD_NT: usize = 16; +mod aead_params { + pub const AES_128_GCM_ID: u16 = 0x0001; + pub const AES_128_GCM_NK: usize = 16; + pub const AES_128_GCM_NN: usize = 12; + pub const AES_128_GCM_NT: usize = 16; +} #[allow(clippy::upper_case_acronyms)] #[pyo3::pyclass( @@ -38,9 +43,6 @@ pub(crate) enum KEM { X25519, } -#[pyo3::pymethods] -impl KEM {} - #[allow(clippy::upper_case_acronyms)] #[allow(non_camel_case_types)] #[pyo3::pyclass( @@ -54,9 +56,6 @@ pub(crate) enum KDF { HKDF_SHA256, } -#[pyo3::pymethods] -impl KDF {} - #[allow(clippy::upper_case_acronyms)] #[allow(non_camel_case_types)] #[pyo3::pyclass( @@ -70,9 +69,6 @@ pub(crate) enum AEAD { AES_128_GCM, } -#[pyo3::pymethods] -impl AEAD {} - #[pyo3::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.openssl.hpke")] pub(crate) struct Suite { kem_suite_id: [u8; 5], @@ -80,23 +76,6 @@ pub(crate) struct Suite { } impl Suite { - fn hkdf_extract( - &self, - py: pyo3::Python<'_>, - salt: &[u8], - ikm: &[u8], - ) -> CryptographyResult { - let sha256 = types::SHA256.get(py)?.call0()?; - let digest_size = sha256 - .getattr(pyo3::intern!(py, "digest_size"))? - .extract::()?; - let default_salt = vec![0u8; digest_size]; - let salt_bytes = if salt.is_empty() { &default_salt } else { salt }; - let mut hmac = Hmac::new_bytes(py, salt_bytes, &sha256)?; - hmac.update_bytes(ikm)?; - hmac.finalize_bytes() - } - fn hkdf_expand( &self, py: pyo3::Python<'_>, @@ -104,12 +83,12 @@ impl Suite { info: &[u8], length: usize, ) -> CryptographyResult> { - let sha256 = types::SHA256.get(py)?.call0()?; - let digest_size = sha256 + let algorithm = crate::types::SHA256.get(py)?.call0()?; + let digest_size = algorithm .getattr(pyo3::intern!(py, "digest_size"))? .extract::()?; - let h_prime = Hmac::new_bytes(py, prk, &sha256)?; + let h_prime = Hmac::new_bytes(py, prk, &algorithm)?; let mut output = vec![0u8; length]; let mut pos = 0usize; @@ -147,7 +126,15 @@ impl Suite { labeled_ikm.extend_from_slice(&self.kem_suite_id); labeled_ikm.extend_from_slice(label); labeled_ikm.extend_from_slice(ikm); - self.hkdf_extract(py, salt, &labeled_ikm) + + let algorithm = crate::types::SHA256.get(py)?.call0()?; + let buf = CffiBuf::from_bytes(py, &labeled_ikm); + let salt_py = if salt.is_empty() { + None + } else { + Some(pyo3::types::PyBytes::new(py, salt).unbind()) + }; + hkdf_extract(py, &algorithm.unbind(), salt_py.as_ref(), &buf) } fn kem_labeled_expand( @@ -175,7 +162,13 @@ impl Suite { kem_context: &[u8], ) -> CryptographyResult> { let eae_prk = self.kem_labeled_extract(py, b"", b"eae_prk", dh)?; - self.kem_labeled_expand(py, &eae_prk, b"shared_secret", kem_context, KEM_NSECRET) + self.kem_labeled_expand( + py, + &eae_prk, + b"shared_secret", + kem_context, + kem_params::X25519_NSECRET, + ) } fn encap( @@ -183,30 +176,33 @@ impl Suite { py: pyo3::Python<'_>, pk_r: &pyo3::Bound<'_, pyo3::PyAny>, ) -> CryptographyResult<(Vec, Vec)> { - // Generate ephemeral key pair using OpenSSL directly - let sk_e_pkey = openssl::pkey::PKey::generate_x25519()?; - let pk_e_raw = sk_e_pkey.raw_public_key()?; - - // Get recipient's public key raw bytes via Python API + // Generate ephemeral key pair using x25519 module + let x25519_mod = py.import(pyo3::intern!( + py, + "cryptography.hazmat.primitives.asymmetric.x25519" + ))?; + let sk_e = x25519_mod + .getattr(pyo3::intern!(py, "X25519PrivateKey"))? + .call_method0(pyo3::intern!(py, "generate"))?; + let pk_e = sk_e.call_method0(pyo3::intern!(py, "public_key"))?; + + // Get ephemeral public key raw bytes + let pk_e_bytes = pk_e.call_method0(pyo3::intern!(py, "public_bytes_raw"))?; + let pk_e_raw = pk_e_bytes.extract::<&[u8]>()?; + + // Get recipient's public key raw bytes let pk_r_bytes = pk_r.call_method0(pyo3::intern!(py, "public_bytes_raw"))?; let pk_r_raw = pk_r_bytes.extract::<&[u8]>()?; - // Create recipient public key from raw bytes - let pk_r_pkey = - openssl::pkey::PKey::public_key_from_raw_bytes(pk_r_raw, openssl::pkey::Id::X25519)?; - - // Perform ECDH - let mut deriver = openssl::derive::Deriver::new(&sk_e_pkey)?; - deriver.set_peer(&pk_r_pkey)?; - let mut dh = vec![0u8; deriver.len()?]; - let n = deriver.derive(&mut dh)?; - assert_eq!(n, dh.len()); + // Perform ECDH via Python API + let dh_result = sk_e.call_method1(pyo3::intern!(py, "exchange"), (pk_r,))?; + let dh = dh_result.extract::<&[u8]>()?; let mut kem_context = [0u8; 64]; - kem_context[..32].copy_from_slice(&pk_e_raw); + kem_context[..32].copy_from_slice(pk_e_raw); kem_context[32..].copy_from_slice(pk_r_raw); - let shared_secret = self.extract_and_expand(py, &dh, &kem_context)?; - Ok((shared_secret, pk_e_raw)) + let shared_secret = self.extract_and_expand(py, dh, &kem_context)?; + Ok((shared_secret, pk_e_raw.to_vec())) } fn decap( @@ -215,37 +211,33 @@ impl Suite { enc: &[u8], sk_r: &pyo3::Bound<'_, pyo3::PyAny>, ) -> CryptographyResult> { - // Reconstruct pk_e from enc - let pk_e_pkey = - openssl::pkey::PKey::public_key_from_raw_bytes(enc, openssl::pkey::Id::X25519) - .map_err(|_| { - CryptographyError::from(pyo3::exceptions::PyValueError::new_err( - "Invalid encapsulated key", - )) - })?; - - // Get our private key raw bytes via Python API - let sk_r_bytes = sk_r.call_method0(pyo3::intern!(py, "private_bytes_raw"))?; - let sk_r_raw = sk_r_bytes.extract::<&[u8]>()?; - - // Reconstruct private key from raw bytes - let sk_r_pkey = - openssl::pkey::PKey::private_key_from_raw_bytes(sk_r_raw, openssl::pkey::Id::X25519)?; - - // Perform ECDH - let mut deriver = openssl::derive::Deriver::new(&sk_r_pkey)?; - deriver.set_peer(&pk_e_pkey)?; - let mut dh = vec![0u8; deriver.len()?]; - let n = deriver.derive(&mut dh)?; - assert_eq!(n, dh.len()); + // Reconstruct pk_e from enc via Python + let x25519_mod = py.import(pyo3::intern!( + py, + "cryptography.hazmat.primitives.asymmetric.x25519" + ))?; + let pk_e = x25519_mod + .getattr(pyo3::intern!(py, "X25519PublicKey"))? + .call_method1(pyo3::intern!(py, "from_public_bytes"), (enc,)) + .map_err(|_| { + CryptographyError::from(pyo3::exceptions::PyValueError::new_err( + "Invalid encapsulated key", + )) + })?; + + // Perform ECDH via Python API + let dh_result = sk_r.call_method1(pyo3::intern!(py, "exchange"), (&pk_e,))?; + let dh = dh_result.extract::<&[u8]>()?; // Get our public key - let pk_rm = sk_r_pkey.raw_public_key()?; + let pk_rm = sk_r.call_method0(pyo3::intern!(py, "public_key"))?; + let pk_rm_bytes = pk_rm.call_method0(pyo3::intern!(py, "public_bytes_raw"))?; + let pk_rm_raw = pk_rm_bytes.extract::<&[u8]>()?; let mut kem_context = [0u8; 64]; kem_context[..32].copy_from_slice(enc); - kem_context[32..].copy_from_slice(&pk_rm); - self.extract_and_expand(py, &dh, &kem_context) + kem_context[32..].copy_from_slice(pk_rm_raw); + self.extract_and_expand(py, dh, &kem_context) } fn hpke_labeled_extract( @@ -260,7 +252,15 @@ impl Suite { labeled_ikm.extend_from_slice(&self.hpke_suite_id); labeled_ikm.extend_from_slice(label); labeled_ikm.extend_from_slice(ikm); - self.hkdf_extract(py, salt, &labeled_ikm) + + let algorithm = crate::types::SHA256.get(py)?.call0()?; + let buf = CffiBuf::from_bytes(py, &labeled_ikm); + let salt_py = if salt.is_empty() { + None + } else { + Some(pyo3::types::PyBytes::new(py, salt).unbind()) + }; + hkdf_extract(py, &algorithm.unbind(), salt_py.as_ref(), &buf) } fn hpke_labeled_expand( @@ -286,7 +286,10 @@ impl Suite { py: pyo3::Python<'_>, shared_secret: &[u8], info: &[u8], - ) -> CryptographyResult<([u8; AEAD_NK], [u8; AEAD_NN])> { + ) -> CryptographyResult<( + [u8; aead_params::AES_128_GCM_NK], + [u8; aead_params::AES_128_GCM_NN], + )> { let psk_id_hash = self.hpke_labeled_extract(py, b"", b"psk_id_hash", b"")?; let info_hash = self.hpke_labeled_extract(py, b"", b"info_hash", info)?; let mut key_schedule_context = vec![HPKE_MODE_BASE]; @@ -295,93 +298,28 @@ impl Suite { let secret = self.hpke_labeled_extract(py, shared_secret, b"secret", b"")?; - let key_vec = - self.hpke_labeled_expand(py, &secret, b"key", &key_schedule_context, AEAD_NK)?; - let nonce_vec = - self.hpke_labeled_expand(py, &secret, b"base_nonce", &key_schedule_context, AEAD_NN)?; - - let mut key = [0u8; AEAD_NK]; - let mut base_nonce = [0u8; AEAD_NN]; + let key_vec = self.hpke_labeled_expand( + py, + &secret, + b"key", + &key_schedule_context, + aead_params::AES_128_GCM_NK, + )?; + let nonce_vec = self.hpke_labeled_expand( + py, + &secret, + b"base_nonce", + &key_schedule_context, + aead_params::AES_128_GCM_NN, + )?; + + let mut key = [0u8; aead_params::AES_128_GCM_NK]; + let mut base_nonce = [0u8; aead_params::AES_128_GCM_NN]; key.copy_from_slice(&key_vec); base_nonce.copy_from_slice(&nonce_vec); Ok((key, base_nonce)) } - - fn aead_encrypt( - &self, - key: &[u8], - nonce: &[u8], - plaintext: &[u8], - aad: &[u8], - ) -> CryptographyResult> { - let cipher = openssl::cipher::Cipher::aes_128_gcm(); - - let mut ctx = openssl::cipher_ctx::CipherCtx::new()?; - ctx.encrypt_init(Some(cipher), Some(key), None)?; - ctx.set_iv_length(nonce.len())?; - ctx.encrypt_init(None, None, Some(nonce))?; - - // Process AAD - if !aad.is_empty() { - ctx.cipher_update(aad, None)?; - } - - // Encrypt plaintext - let mut ciphertext = vec![0u8; plaintext.len() + AEAD_NT]; - let n = ctx.cipher_update(plaintext, Some(&mut ciphertext[..plaintext.len()]))?; - assert_eq!(n, plaintext.len()); - - let mut final_block = [0u8; 0]; - let n = ctx.cipher_final(&mut final_block)?; - assert_eq!(n, 0); - - // Get tag - ctx.tag(&mut ciphertext[plaintext.len()..]) - .map_err(CryptographyError::from)?; - - Ok(ciphertext) - } - - fn aead_decrypt( - &self, - key: &[u8], - nonce: &[u8], - ciphertext: &[u8], - aad: &[u8], - ) -> CryptographyResult> { - if ciphertext.len() < AEAD_NT { - return Err(CryptographyError::from(exceptions::InvalidTag::new_err(()))); - } - - let cipher = openssl::cipher::Cipher::aes_128_gcm(); - - let ct_len = ciphertext.len() - AEAD_NT; - let (ct_data, tag) = ciphertext.split_at(ct_len); - - let mut ctx = openssl::cipher_ctx::CipherCtx::new()?; - ctx.decrypt_init(Some(cipher), Some(key), None)?; - ctx.set_iv_length(nonce.len())?; - ctx.decrypt_init(None, None, Some(nonce))?; - ctx.set_tag(tag)?; - - // Process AAD - if !aad.is_empty() { - ctx.cipher_update(aad, None)?; - } - - // Decrypt ciphertext - let mut plaintext = vec![0u8; ct_len]; - let n = ctx - .cipher_update(ct_data, Some(&mut plaintext)) - .map_err(|_| exceptions::InvalidTag::new_err(()))?; - assert_eq!(n, ct_len); - - ctx.cipher_final(&mut []) - .map_err(|_| exceptions::InvalidTag::new_err(()))?; - - Ok(plaintext) - } } #[pyo3::pymethods] @@ -391,13 +329,13 @@ impl Suite { // Build suite IDs let mut kem_suite_id = [0u8; 5]; kem_suite_id[..3].copy_from_slice(b"KEM"); - kem_suite_id[3..].copy_from_slice(&KEM_ID.to_be_bytes()); + kem_suite_id[3..].copy_from_slice(&kem_params::X25519_ID.to_be_bytes()); let mut hpke_suite_id = [0u8; 10]; hpke_suite_id[..4].copy_from_slice(b"HPKE"); - hpke_suite_id[4..6].copy_from_slice(&KEM_ID.to_be_bytes()); - hpke_suite_id[6..8].copy_from_slice(&KDF_ID.to_be_bytes()); - hpke_suite_id[8..10].copy_from_slice(&AEAD_ID.to_be_bytes()); + hpke_suite_id[4..6].copy_from_slice(&kem_params::X25519_ID.to_be_bytes()); + hpke_suite_id[6..8].copy_from_slice(&kdf_params::HKDF_SHA256_ID.to_be_bytes()); + hpke_suite_id[8..10].copy_from_slice(&aead_params::AES_128_GCM_ID.to_be_bytes()); Ok(Suite { kem_suite_id, @@ -419,15 +357,33 @@ impl Suite { let (shared_secret, enc) = self.encap(py, public_key)?; let (key, base_nonce) = self.key_schedule(py, &shared_secret, info_bytes)?; - let ct = self.aead_encrypt(&key, &base_nonce, plaintext.as_bytes(), aad_bytes)?; - // Combine enc + ct + // Create AEAD with the derived key + let cipher = openssl::cipher::Cipher::aes_128_gcm(); + let aead = EvpCipherAead::new(cipher, &key, aead_params::AES_128_GCM_NT, false)?; + + let pt_bytes = plaintext.as_bytes(); + let ct_len = pt_bytes.len() + aead_params::AES_128_GCM_NT; + Ok(pyo3::types::PyBytes::new_with( py, - enc.len() + ct.len(), + enc.len() + ct_len, |buf| { buf[..enc.len()].copy_from_slice(&enc); - buf[enc.len()..].copy_from_slice(&ct); + let aad_opt = if aad_bytes.is_empty() { + None + } else { + Some(crate::backend::aead::Aad::Single(CffiBuf::from_bytes( + py, aad_bytes, + ))) + }; + aead.encrypt_into( + py, + pt_bytes, + aad_opt, + Some(&base_nonce), + &mut buf[enc.len()..], + )?; Ok(()) }, )?) @@ -443,20 +399,38 @@ impl Suite { aad: Option>, ) -> CryptographyResult> { let ct_bytes = ciphertext.as_bytes(); - if ct_bytes.len() < KEM_NENC { + if ct_bytes.len() < kem_params::X25519_NENC { return Err(CryptographyError::from(exceptions::InvalidTag::new_err(()))); } let info_bytes: &[u8] = info.as_ref().map(|b| b.as_bytes()).unwrap_or(b""); let aad_bytes: &[u8] = aad.as_ref().map(|b| b.as_bytes()).unwrap_or(b""); - let (enc, ct) = ct_bytes.split_at(KEM_NENC); + let (enc, ct) = ct_bytes.split_at(kem_params::X25519_NENC); let shared_secret = self.decap(py, enc, private_key)?; let (key, base_nonce) = self.key_schedule(py, &shared_secret, info_bytes)?; - let plaintext = self.aead_decrypt(&key, &base_nonce, ct, aad_bytes)?; - Ok(pyo3::types::PyBytes::new(py, &plaintext)) + // Create AEAD with the derived key + let cipher = openssl::cipher::Cipher::aes_128_gcm(); + let aead = EvpCipherAead::new(cipher, &key, aead_params::AES_128_GCM_NT, false)?; + + if ct.len() < aead_params::AES_128_GCM_NT { + return Err(CryptographyError::from(exceptions::InvalidTag::new_err(()))); + } + let pt_len = ct.len() - aead_params::AES_128_GCM_NT; + + Ok(pyo3::types::PyBytes::new_with(py, pt_len, |buf| { + let aad_opt = if aad_bytes.is_empty() { + None + } else { + Some(crate::backend::aead::Aad::Single(CffiBuf::from_bytes( + py, aad_bytes, + ))) + }; + aead.decrypt_into(py, ct, aad_opt, Some(&base_nonce), buf)?; + Ok(()) + })?) } } diff --git a/src/rust/src/backend/kdf.rs b/src/rust/src/backend/kdf.rs index 3c1838097a9d..e3f168225ed3 100644 --- a/src/rust/src/backend/kdf.rs +++ b/src/rust/src/backend/kdf.rs @@ -954,7 +954,7 @@ struct Hkdf { used: bool, } -fn hkdf_extract( +pub(crate) fn hkdf_extract( py: pyo3::Python<'_>, algorithm: &pyo3::Py, salt: Option<&pyo3::Py>, From bc2a348f3f159dbec2edfcdf4e55d4208bca8036 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 29 Jan 2026 00:12:40 +0000 Subject: [PATCH 4/7] Use LazyPyImport pattern and Python types for HPKE - Add LazyPyImport entries for X25519_PRIVATE_KEY, X25519_PUBLIC_KEY, HKDF_EXPAND, and AESGCM to types.rs - Replace manual HKDF expand implementation with Python HKDFExpand class - Use Python AESGCM for encryption/decryption instead of EvpCipherAead - Use LazyPyImport for X25519 key operations https://claude.ai/code/session_01W43m9LudrvqkHKr4BrMa7c --- src/rust/src/backend/hpke.rs | 155 +++++++++++++++-------------------- src/rust/src/types.rs | 15 ++++ 2 files changed, 82 insertions(+), 88 deletions(-) diff --git a/src/rust/src/backend/hpke.rs b/src/rust/src/backend/hpke.rs index f74c10019c5a..cbad1e578410 100644 --- a/src/rust/src/backend/hpke.rs +++ b/src/rust/src/backend/hpke.rs @@ -2,13 +2,12 @@ // 2.0, and the BSD License. See the LICENSE file in the root of this repository // for complete details. -use crate::backend::aead::EvpCipherAead; -use crate::backend::hmac::Hmac; use crate::backend::kdf::hkdf_extract; use crate::buf::CffiBuf; use crate::error::{CryptographyError, CryptographyResult}; use crate::exceptions; -use pyo3::types::{PyAnyMethods, PyBytesMethods}; +use crate::types; +use pyo3::types::PyAnyMethods; const HPKE_VERSION: &[u8] = b"HPKE-v1"; const HPKE_MODE_BASE: u8 = 0x00; @@ -28,7 +27,6 @@ mod aead_params { pub const AES_128_GCM_ID: u16 = 0x0001; pub const AES_128_GCM_NK: usize = 16; pub const AES_128_GCM_NN: usize = 12; - pub const AES_128_GCM_NT: usize = 16; } #[allow(clippy::upper_case_acronyms)] @@ -83,35 +81,17 @@ impl Suite { info: &[u8], length: usize, ) -> CryptographyResult> { - let algorithm = crate::types::SHA256.get(py)?.call0()?; - let digest_size = algorithm - .getattr(pyo3::intern!(py, "digest_size"))? - .extract::()?; - - let h_prime = Hmac::new_bytes(py, prk, &algorithm)?; - - let mut output = vec![0u8; length]; - let mut pos = 0usize; - let mut counter = 0u8; - - while pos < length { - counter += 1; - let mut h = h_prime.copy(py)?; - - let start = pos.saturating_sub(digest_size); - h.update_bytes(&output[start..pos])?; - h.update_bytes(info)?; - h.update_bytes(&[counter])?; - - let block = h.finalize(py)?; - let block_bytes = block.as_bytes(); - - let copy_len = (length - pos).min(digest_size); - output[pos..pos + copy_len].copy_from_slice(&block_bytes[..copy_len]); - pos += copy_len; - } - - Ok(output) + let algorithm = types::SHA256.get(py)?.call0()?; + let hkdf_expand = types::HKDF_EXPAND.get(py)?.call1(( + &algorithm, + length, + pyo3::types::PyBytes::new(py, info), + ))?; + let result = hkdf_expand.call_method1( + pyo3::intern!(py, "derive"), + (pyo3::types::PyBytes::new(py, prk),), + )?; + Ok(result.extract::>()?) } fn kem_labeled_extract( @@ -127,7 +107,7 @@ impl Suite { labeled_ikm.extend_from_slice(label); labeled_ikm.extend_from_slice(ikm); - let algorithm = crate::types::SHA256.get(py)?.call0()?; + let algorithm = types::SHA256.get(py)?.call0()?; let buf = CffiBuf::from_bytes(py, &labeled_ikm); let salt_py = if salt.is_empty() { None @@ -176,13 +156,9 @@ impl Suite { py: pyo3::Python<'_>, pk_r: &pyo3::Bound<'_, pyo3::PyAny>, ) -> CryptographyResult<(Vec, Vec)> { - // Generate ephemeral key pair using x25519 module - let x25519_mod = py.import(pyo3::intern!( - py, - "cryptography.hazmat.primitives.asymmetric.x25519" - ))?; - let sk_e = x25519_mod - .getattr(pyo3::intern!(py, "X25519PrivateKey"))? + // Generate ephemeral key pair + let sk_e = types::X25519_PRIVATE_KEY + .get(py)? .call_method0(pyo3::intern!(py, "generate"))?; let pk_e = sk_e.call_method0(pyo3::intern!(py, "public_key"))?; @@ -211,13 +187,9 @@ impl Suite { enc: &[u8], sk_r: &pyo3::Bound<'_, pyo3::PyAny>, ) -> CryptographyResult> { - // Reconstruct pk_e from enc via Python - let x25519_mod = py.import(pyo3::intern!( - py, - "cryptography.hazmat.primitives.asymmetric.x25519" - ))?; - let pk_e = x25519_mod - .getattr(pyo3::intern!(py, "X25519PublicKey"))? + // Reconstruct pk_e from enc + let pk_e = types::X25519_PUBLIC_KEY + .get(py)? .call_method1(pyo3::intern!(py, "from_public_bytes"), (enc,)) .map_err(|_| { CryptographyError::from(pyo3::exceptions::PyValueError::new_err( @@ -253,7 +225,7 @@ impl Suite { labeled_ikm.extend_from_slice(label); labeled_ikm.extend_from_slice(ikm); - let algorithm = crate::types::SHA256.get(py)?.call0()?; + let algorithm = types::SHA256.get(py)?.call0()?; let buf = CffiBuf::from_bytes(py, &labeled_ikm); let salt_py = if salt.is_empty() { None @@ -358,32 +330,35 @@ impl Suite { let (shared_secret, enc) = self.encap(py, public_key)?; let (key, base_nonce) = self.key_schedule(py, &shared_secret, info_bytes)?; - // Create AEAD with the derived key - let cipher = openssl::cipher::Cipher::aes_128_gcm(); - let aead = EvpCipherAead::new(cipher, &key, aead_params::AES_128_GCM_NT, false)?; + // Create AESGCM with the derived key + let aesgcm = types::AESGCM + .get(py)? + .call1((pyo3::types::PyBytes::new(py, &key),))?; + + // Encrypt using AESGCM + let aad_arg = if aad_bytes.is_empty() { + py.None().into_bound(py) + } else { + pyo3::types::PyBytes::new(py, aad_bytes).into_any() + }; - let pt_bytes = plaintext.as_bytes(); - let ct_len = pt_bytes.len() + aead_params::AES_128_GCM_NT; + let ct = aesgcm.call_method1( + pyo3::intern!(py, "encrypt"), + ( + pyo3::types::PyBytes::new(py, &base_nonce), + pyo3::types::PyBytes::new(py, plaintext.as_bytes()), + aad_arg, + ), + )?; + let ct_bytes = ct.extract::<&[u8]>()?; + // Combine enc + ct Ok(pyo3::types::PyBytes::new_with( py, - enc.len() + ct_len, + enc.len() + ct_bytes.len(), |buf| { buf[..enc.len()].copy_from_slice(&enc); - let aad_opt = if aad_bytes.is_empty() { - None - } else { - Some(crate::backend::aead::Aad::Single(CffiBuf::from_bytes( - py, aad_bytes, - ))) - }; - aead.encrypt_into( - py, - pt_bytes, - aad_opt, - Some(&base_nonce), - &mut buf[enc.len()..], - )?; + buf[enc.len()..].copy_from_slice(ct_bytes); Ok(()) }, )?) @@ -411,26 +386,30 @@ impl Suite { let shared_secret = self.decap(py, enc, private_key)?; let (key, base_nonce) = self.key_schedule(py, &shared_secret, info_bytes)?; - // Create AEAD with the derived key - let cipher = openssl::cipher::Cipher::aes_128_gcm(); - let aead = EvpCipherAead::new(cipher, &key, aead_params::AES_128_GCM_NT, false)?; + // Create AESGCM with the derived key + let aesgcm = types::AESGCM + .get(py)? + .call1((pyo3::types::PyBytes::new(py, &key),))?; - if ct.len() < aead_params::AES_128_GCM_NT { - return Err(CryptographyError::from(exceptions::InvalidTag::new_err(()))); - } - let pt_len = ct.len() - aead_params::AES_128_GCM_NT; - - Ok(pyo3::types::PyBytes::new_with(py, pt_len, |buf| { - let aad_opt = if aad_bytes.is_empty() { - None - } else { - Some(crate::backend::aead::Aad::Single(CffiBuf::from_bytes( - py, aad_bytes, - ))) - }; - aead.decrypt_into(py, ct, aad_opt, Some(&base_nonce), buf)?; - Ok(()) - })?) + // Decrypt using AESGCM + let aad_arg = if aad_bytes.is_empty() { + py.None().into_bound(py) + } else { + pyo3::types::PyBytes::new(py, aad_bytes).into_any() + }; + + let pt = aesgcm + .call_method1( + pyo3::intern!(py, "decrypt"), + ( + pyo3::types::PyBytes::new(py, &base_nonce), + pyo3::types::PyBytes::new(py, ct), + aad_arg, + ), + ) + .map_err(|_| exceptions::InvalidTag::new_err(()))?; + + Ok(pt.extract::>()?) } } diff --git a/src/rust/src/types.rs b/src/rust/src/types.rs index 3a7b02c6d69e..8a128c6c1da9 100644 --- a/src/rust/src/types.rs +++ b/src/rust/src/types.rs @@ -486,6 +486,15 @@ pub static ED448_PUBLIC_KEY: LazyPyImport = LazyPyImport::new( &["Ed448PublicKey"], ); +pub static X25519_PRIVATE_KEY: LazyPyImport = LazyPyImport::new( + "cryptography.hazmat.primitives.asymmetric.x25519", + &["X25519PrivateKey"], +); +pub static X25519_PUBLIC_KEY: LazyPyImport = LazyPyImport::new( + "cryptography.hazmat.primitives.asymmetric.x25519", + &["X25519PublicKey"], +); + pub static DSA_PRIVATE_KEY: LazyPyImport = LazyPyImport::new( "cryptography.hazmat.primitives.asymmetric.dsa", &["DSAPrivateKey"], @@ -613,6 +622,12 @@ pub static KBKDF_COUNTER_LOCATION: LazyPyImport = LazyPyImport::new( &["CounterLocation"], ); +pub static HKDF_EXPAND: LazyPyImport = + LazyPyImport::new("cryptography.hazmat.primitives.kdf.hkdf", &["HKDFExpand"]); + +pub static AESGCM: LazyPyImport = + LazyPyImport::new("cryptography.hazmat.primitives.ciphers.aead", &["AESGCM"]); + #[cfg(test)] mod tests { use super::LazyPyImport; From d137e1627cbcd3cb11fd3f99d1208214213001e2 Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Wed, 28 Jan 2026 19:45:00 -0500 Subject: [PATCH 5/7] clean up style --- src/rust/src/backend/aead.rs | 11 +- src/rust/src/backend/hpke.rs | 193 ++++++++++++--------------------- src/rust/src/backend/kdf.rs | 33 +++--- src/rust/src/backend/x25519.rs | 4 +- src/rust/src/types.rs | 15 --- 5 files changed, 95 insertions(+), 161 deletions(-) diff --git a/src/rust/src/backend/aead.rs b/src/rust/src/backend/aead.rs index f1afeb2d5c66..a5fc3478d651 100644 --- a/src/rust/src/backend/aead.rs +++ b/src/rust/src/backend/aead.rs @@ -626,7 +626,7 @@ impl ChaCha20Poly1305 { name = "AESGCM" )] // NO-COVERAGE-END -struct AesGcm { +pub(crate) struct AesGcm { #[cfg(any( CRYPTOGRAPHY_OPENSSL_320_OR_GREATER, CRYPTOGRAPHY_IS_LIBRESSL, @@ -647,7 +647,10 @@ struct AesGcm { #[pyo3::pymethods] impl AesGcm { #[new] - fn new(py: pyo3::Python<'_>, key: pyo3::Py) -> CryptographyResult { + pub(crate) fn new( + py: pyo3::Python<'_>, + key: pyo3::Py, + ) -> CryptographyResult { let key_buf = key.extract::>(py)?; let cipher = match key_buf.as_bytes().len() { 16 => openssl::cipher::Cipher::aes_128_gcm(), @@ -696,7 +699,7 @@ impl AesGcm { } #[pyo3(signature = (nonce, data, associated_data))] - fn encrypt<'p>( + pub(crate) fn encrypt<'p>( &self, py: pyo3::Python<'p>, nonce: CffiBuf<'_>, @@ -754,7 +757,7 @@ impl AesGcm { } #[pyo3(signature = (nonce, data, associated_data))] - fn decrypt<'p>( + pub(crate) fn decrypt<'p>( &self, py: pyo3::Python<'p>, nonce: CffiBuf<'_>, diff --git a/src/rust/src/backend/hpke.rs b/src/rust/src/backend/hpke.rs index cbad1e578410..7bed017d4591 100644 --- a/src/rust/src/backend/hpke.rs +++ b/src/rust/src/backend/hpke.rs @@ -2,17 +2,18 @@ // 2.0, and the BSD License. See the LICENSE file in the root of this repository // for complete details. -use crate::backend::kdf::hkdf_extract; +use crate::backend::aead::AesGcm; +use crate::backend::kdf::{hkdf_extract, HkdfExpand}; +use crate::backend::x25519; use crate::buf::CffiBuf; use crate::error::{CryptographyError, CryptographyResult}; use crate::exceptions; use crate::types; -use pyo3::types::PyAnyMethods; +use pyo3::types::{PyAnyMethods, PyBytesMethods}; const HPKE_VERSION: &[u8] = b"HPKE-v1"; const HPKE_MODE_BASE: u8 = 0x00; -// Algorithm parameters organized by type for easy extension mod kem_params { pub const X25519_ID: u16 = 0x0020; pub const X25519_NSECRET: usize = 32; @@ -74,30 +75,28 @@ pub(crate) struct Suite { } impl Suite { - fn hkdf_expand( + fn hkdf_expand<'p>( &self, - py: pyo3::Python<'_>, + py: pyo3::Python<'p>, prk: &[u8], info: &[u8], length: usize, - ) -> CryptographyResult> { + ) -> CryptographyResult> { let algorithm = types::SHA256.get(py)?.call0()?; - let hkdf_expand = types::HKDF_EXPAND.get(py)?.call1(( - &algorithm, + + let mut hkdf_expand = HkdfExpand::new( + py, + algorithm.unbind(), length, - pyo3::types::PyBytes::new(py, info), - ))?; - let result = hkdf_expand.call_method1( - pyo3::intern!(py, "derive"), - (pyo3::types::PyBytes::new(py, prk),), + Some(pyo3::types::PyBytes::new(py, info).unbind()), + None, )?; - Ok(result.extract::>()?) + hkdf_expand.derive(py, CffiBuf::from_bytes(py, prk)) } fn kem_labeled_extract( &self, py: pyo3::Python<'_>, - salt: &[u8], label: &[u8], ikm: &[u8], ) -> CryptographyResult { @@ -109,22 +108,17 @@ impl Suite { let algorithm = types::SHA256.get(py)?.call0()?; let buf = CffiBuf::from_bytes(py, &labeled_ikm); - let salt_py = if salt.is_empty() { - None - } else { - Some(pyo3::types::PyBytes::new(py, salt).unbind()) - }; - hkdf_extract(py, &algorithm.unbind(), salt_py.as_ref(), &buf) + hkdf_extract(py, &algorithm.unbind(), None, &buf) } - fn kem_labeled_expand( + fn kem_labeled_expand<'p>( &self, - py: pyo3::Python<'_>, + py: pyo3::Python<'p>, prk: &[u8], label: &[u8], info: &[u8], length: usize, - ) -> CryptographyResult> { + ) -> CryptographyResult> { let mut labeled_info = Vec::with_capacity(2 + HPKE_VERSION.len() + 5 + label.len() + info.len()); labeled_info.extend_from_slice(&(length as u16).to_be_bytes()); @@ -135,13 +129,13 @@ impl Suite { self.hkdf_expand(py, prk, &labeled_info, length) } - fn extract_and_expand( + fn extract_and_expand<'p>( &self, - py: pyo3::Python<'_>, + py: pyo3::Python<'p>, dh: &[u8], kem_context: &[u8], - ) -> CryptographyResult> { - let eae_prk = self.kem_labeled_extract(py, b"", b"eae_prk", dh)?; + ) -> CryptographyResult> { + let eae_prk = self.kem_labeled_extract(py, b"eae_prk", dh)?; self.kem_labeled_expand( py, &eae_prk, @@ -151,57 +145,46 @@ impl Suite { ) } - fn encap( + fn encap<'p>( &self, - py: pyo3::Python<'_>, + py: pyo3::Python<'p>, pk_r: &pyo3::Bound<'_, pyo3::PyAny>, - ) -> CryptographyResult<(Vec, Vec)> { - // Generate ephemeral key pair - let sk_e = types::X25519_PRIVATE_KEY - .get(py)? - .call_method0(pyo3::intern!(py, "generate"))?; + ) -> CryptographyResult<( + pyo3::Bound<'p, pyo3::types::PyBytes>, + pyo3::Bound<'p, pyo3::types::PyBytes>, + )> { + let sk_e = pyo3::Bound::new(py, x25519::generate_key()?)?; let pk_e = sk_e.call_method0(pyo3::intern!(py, "public_key"))?; - // Get ephemeral public key raw bytes - let pk_e_bytes = pk_e.call_method0(pyo3::intern!(py, "public_bytes_raw"))?; - let pk_e_raw = pk_e_bytes.extract::<&[u8]>()?; + let pk_e_bytes: pyo3::Bound<'p, pyo3::types::PyBytes> = pk_e + .call_method0(pyo3::intern!(py, "public_bytes_raw"))? + .extract()?; - // Get recipient's public key raw bytes let pk_r_bytes = pk_r.call_method0(pyo3::intern!(py, "public_bytes_raw"))?; let pk_r_raw = pk_r_bytes.extract::<&[u8]>()?; - // Perform ECDH via Python API let dh_result = sk_e.call_method1(pyo3::intern!(py, "exchange"), (pk_r,))?; let dh = dh_result.extract::<&[u8]>()?; let mut kem_context = [0u8; 64]; - kem_context[..32].copy_from_slice(pk_e_raw); + kem_context[..32].copy_from_slice(pk_e_bytes.as_bytes()); kem_context[32..].copy_from_slice(pk_r_raw); let shared_secret = self.extract_and_expand(py, dh, &kem_context)?; - Ok((shared_secret, pk_e_raw.to_vec())) + Ok((shared_secret, pk_e_bytes)) } - fn decap( + fn decap<'p>( &self, - py: pyo3::Python<'_>, + py: pyo3::Python<'p>, enc: &[u8], sk_r: &pyo3::Bound<'_, pyo3::PyAny>, - ) -> CryptographyResult> { + ) -> CryptographyResult> { // Reconstruct pk_e from enc - let pk_e = types::X25519_PUBLIC_KEY - .get(py)? - .call_method1(pyo3::intern!(py, "from_public_bytes"), (enc,)) - .map_err(|_| { - CryptographyError::from(pyo3::exceptions::PyValueError::new_err( - "Invalid encapsulated key", - )) - })?; - - // Perform ECDH via Python API + let pk_e = pyo3::Bound::new(py, x25519::from_public_bytes(enc)?)?; + let dh_result = sk_r.call_method1(pyo3::intern!(py, "exchange"), (&pk_e,))?; let dh = dh_result.extract::<&[u8]>()?; - // Get our public key let pk_rm = sk_r.call_method0(pyo3::intern!(py, "public_key"))?; let pk_rm_bytes = pk_rm.call_method0(pyo3::intern!(py, "public_bytes_raw"))?; let pk_rm_raw = pk_rm_bytes.extract::<&[u8]>()?; @@ -215,7 +198,7 @@ impl Suite { fn hpke_labeled_extract( &self, py: pyo3::Python<'_>, - salt: &[u8], + salt: Option<&[u8]>, label: &[u8], ikm: &[u8], ) -> CryptographyResult { @@ -227,22 +210,17 @@ impl Suite { let algorithm = types::SHA256.get(py)?.call0()?; let buf = CffiBuf::from_bytes(py, &labeled_ikm); - let salt_py = if salt.is_empty() { - None - } else { - Some(pyo3::types::PyBytes::new(py, salt).unbind()) - }; - hkdf_extract(py, &algorithm.unbind(), salt_py.as_ref(), &buf) + hkdf_extract(py, &algorithm.unbind(), salt, &buf) } - fn hpke_labeled_expand( + fn hpke_labeled_expand<'p>( &self, - py: pyo3::Python<'_>, + py: pyo3::Python<'p>, prk: &[u8], label: &[u8], info: &[u8], length: usize, - ) -> CryptographyResult> { + ) -> CryptographyResult> { let mut labeled_info = Vec::with_capacity(2 + HPKE_VERSION.len() + 10 + label.len() + info.len()); labeled_info.extend_from_slice(&(length as u16).to_be_bytes()); @@ -262,13 +240,13 @@ impl Suite { [u8; aead_params::AES_128_GCM_NK], [u8; aead_params::AES_128_GCM_NN], )> { - let psk_id_hash = self.hpke_labeled_extract(py, b"", b"psk_id_hash", b"")?; - let info_hash = self.hpke_labeled_extract(py, b"", b"info_hash", info)?; + let psk_id_hash = self.hpke_labeled_extract(py, None, b"psk_id_hash", b"")?; + let info_hash = self.hpke_labeled_extract(py, None, b"info_hash", info)?; let mut key_schedule_context = vec![HPKE_MODE_BASE]; key_schedule_context.extend_from_slice(&psk_id_hash); key_schedule_context.extend_from_slice(&info_hash); - let secret = self.hpke_labeled_extract(py, shared_secret, b"secret", b"")?; + let secret = self.hpke_labeled_extract(py, Some(shared_secret), b"secret", b"")?; let key_vec = self.hpke_labeled_expand( py, @@ -287,8 +265,8 @@ impl Suite { let mut key = [0u8; aead_params::AES_128_GCM_NK]; let mut base_nonce = [0u8; aead_params::AES_128_GCM_NN]; - key.copy_from_slice(&key_vec); - base_nonce.copy_from_slice(&nonce_vec); + key.copy_from_slice(key_vec.as_bytes()); + base_nonce.copy_from_slice(nonce_vec.as_bytes()); Ok((key, base_nonce)) } @@ -325,40 +303,21 @@ impl Suite { aad: Option>, ) -> CryptographyResult> { let info_bytes: &[u8] = info.as_ref().map(|b| b.as_bytes()).unwrap_or(b""); - let aad_bytes: &[u8] = aad.as_ref().map(|b| b.as_bytes()).unwrap_or(b""); let (shared_secret, enc) = self.encap(py, public_key)?; - let (key, base_nonce) = self.key_schedule(py, &shared_secret, info_bytes)?; - - // Create AESGCM with the derived key - let aesgcm = types::AESGCM - .get(py)? - .call1((pyo3::types::PyBytes::new(py, &key),))?; - - // Encrypt using AESGCM - let aad_arg = if aad_bytes.is_empty() { - py.None().into_bound(py) - } else { - pyo3::types::PyBytes::new(py, aad_bytes).into_any() - }; - - let ct = aesgcm.call_method1( - pyo3::intern!(py, "encrypt"), - ( - pyo3::types::PyBytes::new(py, &base_nonce), - pyo3::types::PyBytes::new(py, plaintext.as_bytes()), - aad_arg, - ), - )?; - let ct_bytes = ct.extract::<&[u8]>()?; + let (key, base_nonce) = self.key_schedule(py, shared_secret.as_bytes(), info_bytes)?; + + let aesgcm = AesGcm::new(py, pyo3::types::PyBytes::new(py, &key).unbind().into_any())?; + let ct = aesgcm.encrypt(py, CffiBuf::from_bytes(py, &base_nonce), plaintext, aad)?; - // Combine enc + ct + let enc_bytes = enc.as_bytes(); + let ct_bytes = ct.as_bytes(); Ok(pyo3::types::PyBytes::new_with( py, - enc.len() + ct_bytes.len(), + enc_bytes.len() + ct_bytes.len(), |buf| { - buf[..enc.len()].copy_from_slice(&enc); - buf[enc.len()..].copy_from_slice(ct_bytes); + buf[..enc_bytes.len()].copy_from_slice(enc_bytes); + buf[enc_bytes.len()..].copy_from_slice(ct_bytes); Ok(()) }, )?) @@ -379,37 +338,19 @@ impl Suite { } let info_bytes: &[u8] = info.as_ref().map(|b| b.as_bytes()).unwrap_or(b""); - let aad_bytes: &[u8] = aad.as_ref().map(|b| b.as_bytes()).unwrap_or(b""); let (enc, ct) = ct_bytes.split_at(kem_params::X25519_NENC); let shared_secret = self.decap(py, enc, private_key)?; - let (key, base_nonce) = self.key_schedule(py, &shared_secret, info_bytes)?; - - // Create AESGCM with the derived key - let aesgcm = types::AESGCM - .get(py)? - .call1((pyo3::types::PyBytes::new(py, &key),))?; - - // Decrypt using AESGCM - let aad_arg = if aad_bytes.is_empty() { - py.None().into_bound(py) - } else { - pyo3::types::PyBytes::new(py, aad_bytes).into_any() - }; - - let pt = aesgcm - .call_method1( - pyo3::intern!(py, "decrypt"), - ( - pyo3::types::PyBytes::new(py, &base_nonce), - pyo3::types::PyBytes::new(py, ct), - aad_arg, - ), - ) - .map_err(|_| exceptions::InvalidTag::new_err(()))?; - - Ok(pt.extract::>()?) + let (key, base_nonce) = self.key_schedule(py, shared_secret.as_bytes(), info_bytes)?; + + let aesgcm = AesGcm::new(py, pyo3::types::PyBytes::new(py, &key).unbind().into_any())?; + aesgcm.decrypt( + py, + CffiBuf::from_bytes(py, &base_nonce), + CffiBuf::from_bytes(py, ct), + aad, + ) } } diff --git a/src/rust/src/backend/kdf.rs b/src/rust/src/backend/kdf.rs index e3f168225ed3..40aaf0834521 100644 --- a/src/rust/src/backend/kdf.rs +++ b/src/rust/src/backend/kdf.rs @@ -957,20 +957,15 @@ struct Hkdf { pub(crate) fn hkdf_extract( py: pyo3::Python<'_>, algorithm: &pyo3::Py, - salt: Option<&pyo3::Py>, + salt: Option<&[u8]>, key_material: &CffiBuf<'_>, ) -> CryptographyResult { let algorithm_bound = algorithm.bind(py); let digest_size = algorithm_bound .getattr(pyo3::intern!(py, "digest_size"))? .extract::()?; - let salt_bound = salt.map(|s| s.bind(py)); let default_salt = vec![0; digest_size]; - let salt_bytes: &[u8] = if let Some(bound) = salt_bound { - bound.as_bytes() - } else { - &default_salt - }; + let salt_bytes = salt.unwrap_or(&default_salt); let mut hmac = Hmac::new_bytes(py, salt_bytes, algorithm_bound)?; hmac.update_bytes(key_material.as_bytes())?; @@ -999,7 +994,12 @@ impl Hkdf { } let buf = CffiBuf::from_bytes(py, key_material); - let prk = hkdf_extract(py, &self.algorithm, self.salt.as_ref(), &buf)?; + let prk = hkdf_extract( + py, + &self.algorithm, + self.salt.as_ref().map(|s| s.bind(py)).map(|s| s.as_bytes()), + &buf, + )?; let mut hkdf_expand = HkdfExpand::new( py, self.algorithm.clone_ref(py), @@ -1056,10 +1056,10 @@ impl Hkdf { fn extract<'p>( py: pyo3::Python<'p>, algorithm: pyo3::Py, - salt: Option>, + salt: Option<&[u8]>, key_material: CffiBuf<'_>, ) -> CryptographyResult> { - let prk = hkdf_extract(py, &algorithm, salt.as_ref(), &key_material)?; + let prk = hkdf_extract(py, &algorithm, salt, &key_material)?; Ok(pyo3::types::PyBytes::new(py, &prk)) } @@ -1068,7 +1068,12 @@ impl Hkdf { py: pyo3::Python<'p>, key_material: CffiBuf<'_>, ) -> CryptographyResult> { - let prk = hkdf_extract(py, &self.algorithm, self.salt.as_ref(), &key_material)?; + let prk = hkdf_extract( + py, + &self.algorithm, + self.salt.as_ref().map(|s| s.bind(py)).map(|s| s.as_bytes()), + &key_material, + )?; Ok(pyo3::types::PyBytes::new(py, &prk)) } @@ -1118,7 +1123,7 @@ impl Hkdf { name = "HKDFExpand" )] // NO-COVERAGE-END -struct HkdfExpand { +pub(crate) struct HkdfExpand { algorithm: pyo3::Py, info: pyo3::Py, length: usize, @@ -1181,7 +1186,7 @@ impl HkdfExpand { impl HkdfExpand { #[new] #[pyo3(signature = (algorithm, length, info, backend=None))] - fn new( + pub(crate) fn new( py: pyo3::Python<'_>, algorithm: pyo3::Py, length: usize, @@ -1231,7 +1236,7 @@ impl HkdfExpand { self.derive_into_buffer(py, key_material.as_bytes(), buf.as_mut_bytes()) } - fn derive<'p>( + pub(crate) fn derive<'p>( &mut self, py: pyo3::Python<'p>, key_material: CffiBuf<'_>, diff --git a/src/rust/src/backend/x25519.rs b/src/rust/src/backend/x25519.rs index 9ee092725aec..fac7a9d46ffc 100644 --- a/src/rust/src/backend/x25519.rs +++ b/src/rust/src/backend/x25519.rs @@ -17,7 +17,7 @@ pub(crate) struct X25519PublicKey { } #[pyo3::pyfunction] -fn generate_key() -> CryptographyResult { +pub(crate) fn generate_key() -> CryptographyResult { Ok(X25519PrivateKey { pkey: openssl::pkey::PKey::generate_x25519()?, }) @@ -52,7 +52,7 @@ fn from_private_bytes(data: CffiBuf<'_>) -> pyo3::PyResult { } #[pyo3::pyfunction] -fn from_public_bytes(data: &[u8]) -> pyo3::PyResult { +pub(crate) fn from_public_bytes(data: &[u8]) -> pyo3::PyResult { let pkey = openssl::pkey::PKey::public_key_from_raw_bytes(data, openssl::pkey::Id::X25519) .map_err(|_| { pyo3::exceptions::PyValueError::new_err("An X25519 public key is 32 bytes long") diff --git a/src/rust/src/types.rs b/src/rust/src/types.rs index 8a128c6c1da9..3a7b02c6d69e 100644 --- a/src/rust/src/types.rs +++ b/src/rust/src/types.rs @@ -486,15 +486,6 @@ pub static ED448_PUBLIC_KEY: LazyPyImport = LazyPyImport::new( &["Ed448PublicKey"], ); -pub static X25519_PRIVATE_KEY: LazyPyImport = LazyPyImport::new( - "cryptography.hazmat.primitives.asymmetric.x25519", - &["X25519PrivateKey"], -); -pub static X25519_PUBLIC_KEY: LazyPyImport = LazyPyImport::new( - "cryptography.hazmat.primitives.asymmetric.x25519", - &["X25519PublicKey"], -); - pub static DSA_PRIVATE_KEY: LazyPyImport = LazyPyImport::new( "cryptography.hazmat.primitives.asymmetric.dsa", &["DSAPrivateKey"], @@ -622,12 +613,6 @@ pub static KBKDF_COUNTER_LOCATION: LazyPyImport = LazyPyImport::new( &["CounterLocation"], ); -pub static HKDF_EXPAND: LazyPyImport = - LazyPyImport::new("cryptography.hazmat.primitives.kdf.hkdf", &["HKDFExpand"]); - -pub static AESGCM: LazyPyImport = - LazyPyImport::new("cryptography.hazmat.primitives.ciphers.aead", &["AESGCM"]); - #[cfg(test)] mod tests { use super::LazyPyImport; From cae4b8c2a2c2e6a0222cef13e14eb0f96dc5e250 Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Wed, 28 Jan 2026 19:57:00 -0500 Subject: [PATCH 6/7] test case --- tests/hazmat/primitives/test_hpke.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/hazmat/primitives/test_hpke.py b/tests/hazmat/primitives/test_hpke.py index 794121c4d9c7..9b0a1d204b1d 100644 --- a/tests/hazmat/primitives/test_hpke.py +++ b/tests/hazmat/primitives/test_hpke.py @@ -170,6 +170,8 @@ def test_truncated_ciphertext(self): with pytest.raises(InvalidTag): suite.decrypt(truncated, sk_r) + with pytest.raises(InvalidTag): + suite.decrypt(b"\x00") def test_vector_decryption(self, subtests): vectors = load_vectors_from_file( From e60cb24704a517b0b754015a1a3ee7c9cd7f4f13 Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Wed, 28 Jan 2026 20:09:16 -0500 Subject: [PATCH 7/7] whoops --- tests/hazmat/primitives/test_hpke.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/hazmat/primitives/test_hpke.py b/tests/hazmat/primitives/test_hpke.py index 9b0a1d204b1d..39e6176f9833 100644 --- a/tests/hazmat/primitives/test_hpke.py +++ b/tests/hazmat/primitives/test_hpke.py @@ -171,7 +171,7 @@ def test_truncated_ciphertext(self): with pytest.raises(InvalidTag): suite.decrypt(truncated, sk_r) with pytest.raises(InvalidTag): - suite.decrypt(b"\x00") + suite.decrypt(b"\x00", sk_r) def test_vector_decryption(self, subtests): vectors = load_vectors_from_file(