Skip to content
2 changes: 1 addition & 1 deletion anchor/common/bls_lagrange/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ zeroize = { workspace = true }
blst = { workspace = true }

[features]
default = ["blsful"]
default = ["blst_single_thread"]
blsful = ["dep:blstrs_plus", "dep:vsss-rs"]
blst = ["dep:blst"]
blst_single_thread = ["blst"]
47 changes: 29 additions & 18 deletions anchor/common/bls_lagrange/src/blsful.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use std::num::NonZeroU64;
use blstrs_plus::{G2Projective, Scalar};
use rand::{CryptoRng, Rng};
use vsss_rs::{
shamir, IdentifierPrimeField, ParticipantIdGeneratorType, ReadableShareSet, ValueGroup,
elliptic_curve::Field, shamir, IdentifierPrimeField, ParticipantIdGeneratorType,
ReadableShareSet, ValueGroup,
};
use zeroize::Zeroizing;

Expand Down Expand Up @@ -56,24 +57,21 @@ pub fn split_with_rng(
.try_into()
.map_err(|_| Error::InternalError)?,
);
let key = if result.is_some().into() {
Zeroizing::new(IdentifierPrimeField(result.unwrap()))
} else {
return Err(Error::InternalError);
};
let scalar = result.into_option().ok_or(Error::InternalError)?;
if bool::from(scalar.is_zero()) {
return Err(Error::ZeroKey);
}
let key = Zeroizing::new(IdentifierPrimeField(scalar));

let ids = ids.into_iter().map(|k| k.identifier).collect::<Vec<_>>();

let result = Zeroizing::new(
shamir::split_secret_with_participant_generator(
threshold as usize,
ids.len(),
&*key,
rng,
&[ParticipantIdGeneratorType::List { list: &ids }],
)
.map_err(|_| Error::InternalError)?,
);
let result = Zeroizing::new(shamir::split_secret_with_participant_generator(
threshold as usize,
ids.len(),
&*key,
rng,
&[ParticipantIdGeneratorType::List { list: &ids }],
)?);

result
.iter()
Expand Down Expand Up @@ -111,7 +109,7 @@ pub fn combine_signatures(
.zip(ids)
.map(|(sig, id)| {
let Some(bytes) = sig.serialize_uncompressed() else {
return Err(Error::InternalError);
return Err(Error::InvalidSignature);
};
let g2 = G2Projective::from_uncompressed(&bytes);
if g2.is_some().into() {
Expand All @@ -122,7 +120,20 @@ pub fn combine_signatures(
})
.collect::<Result<Vec<_>, _>>()?;

let result = share_set.combine().map_err(|_| Error::InternalError)?;
let result = share_set.combine()?;
bls::Signature::deserialize_uncompressed(&result.0.to_uncompressed())
.map_err(|_| Error::InternalError)
}

impl From<vsss_rs::Error> for Error {
fn from(value: vsss_rs::Error) -> Self {
match value {
vsss_rs::Error::SharingMinThreshold => Error::InvalidThreshold,
vsss_rs::Error::SharingLimitLessThanThreshold => Error::InvalidThreshold,
vsss_rs::Error::SharingInvalidIdentifier => Error::ZeroId,
vsss_rs::Error::SharingDuplicateIdentifier => Error::RepeatedId,
vsss_rs::Error::InvalidSecret => Error::ZeroKey,
_ => Error::InternalError,
}
}
}
69 changes: 45 additions & 24 deletions anchor/common/bls_lagrange/src/blst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::{
iter::{once, repeat_with},
mem,
num::NonZeroU64,
sync::LazyLock,
};

use bls::Signature;
Expand All @@ -12,18 +11,6 @@ use rand::prelude::*;

use crate::{random_key, Error};

static WARNING: LazyLock<()> = LazyLock::new(|| {
eprintln!(
r#"
#######################################################################################
### YOU ARE USING AN UNAUDITED, UNSAFE IMPLEMENTATION OF BLS LAGRANGE INTERPOLATION ###
### ###
### !!! DO NOT USE IN PRODUCTION !!! ###
#######################################################################################
"#
)
});

#[derive(Debug, Clone)]
pub struct KeyId {
num: u64,
Expand All @@ -40,7 +27,8 @@ impl TryFrom<u64> for KeyId {
if value != 0 {
unsafe {
let mut id = blst_scalar::default();
blst_scalar_from_uint64(&mut id, &value);
let value_le_bytes = value.to_le_bytes();
blst_scalar_from_le_bytes(&mut id, value_le_bytes.as_ptr(), 8);
Ok(KeyId {
num: value,
scalar: id,
Expand All @@ -55,7 +43,8 @@ impl From<NonZeroU64> for KeyId {
fn from(value: NonZeroU64) -> Self {
unsafe {
let mut id = blst_scalar::default();
blst_scalar_from_uint64(&mut id, &value.get());
let value_le_bytes = value.get().to_le_bytes();
blst_scalar_from_le_bytes(&mut id, value_le_bytes.as_ptr(), 8);
KeyId {
num: value.get(),
scalar: id,
Expand All @@ -76,28 +65,33 @@ pub fn split_with_rng(
ids: impl IntoIterator<Item = KeyId>,
rng: &mut (impl CryptoRng + Rng),
) -> Result<Vec<(KeyId, bls::SecretKey)>, Error> {
LazyLock::force(&WARNING);
if threshold <= 1 {
return Err(Error::InvalidThreshold);
}

// `bls::SecretKey` contains a blst `SecretKey`, which zeroizes on drop.
// These are the random coefficients for our polynomial.
let keys = repeat_with(|| random_key(rng))
let random_coefficients = repeat_with(|| random_key(rng))
.take((threshold - 1) as usize)
.collect::<Result<Vec<_>, _>>()?;

// This will always have len == threshold, so it's non-empty
let msk = once(key)
.chain(keys.iter())
let coefficients = once(key)
.chain(random_coefficients.iter())
.map(|key| <&blst_scalar>::from(key.point()))
.collect::<Vec<_>>();

unsafe {
if !blst_sk_check(coefficients[0]) {
return Err(Error::ZeroKey);
}
}

ids.into_iter()
.map(|id| unsafe {
// Compute f(id), which is the secret for the participant with that id.

let mut y = (*msk.last().expect("msk is non-empty")).clone();
let mut y = (*coefficients.last().expect("coefficients is non-empty")).clone();
// As threshold is 2 or greater, this will do at least one iteration.
// At the beginning of the first iteration, y is the coefficient of x^threshold.
// We multiply it by x (=id), and add the coefficient of x^(threshold - 1), until we add
Expand All @@ -111,7 +105,7 @@ pub fn split_with_rng(
if !blst_sk_mul_n_check(&mut y, &y, &id.scalar) {
return Err(Error::ZeroId);
}
assert!(blst_sk_add_n_check(&mut y, &y, msk[i as usize]));
assert!(blst_sk_add_n_check(&mut y, &y, coefficients[i as usize]));
}
// SecretKey is repr(transparent), so the transmute is fine.
// We pass a reference, and afterward, the SecretKey is dropped, zeroizing it.
Expand All @@ -124,7 +118,6 @@ pub fn split_with_rng(
}

pub fn combine_signatures(signatures: &[Signature], ids: &[KeyId]) -> Result<Signature, Error> {
LazyLock::force(&WARNING);
if signatures.len() < 2 {
return Err(Error::LessThanTwoSignatures);
}
Expand Down Expand Up @@ -210,8 +203,7 @@ fn mult(signatures: &[min_pk::Signature], d: &[u8]) -> min_pk::Signature {
let p: [*const blst_p2_affine; 2] = [<&blst_p2_affine>::from(&signatures[0]), std::ptr::null()];
let s: [*const u8; 2] = [&d[0], std::ptr::null()];
unsafe {
let mut scratch: Vec<u64> =
Vec::with_capacity(blst_p2s_mult_pippenger_scratch_sizeof(signatures.len()) / 8);
let mut scratch = vec![0; blst_p2s_mult_pippenger_scratch_sizeof(signatures.len()) / 8];
blst_p2s_mult_pippenger(
&mut ret,
&p[0],
Expand All @@ -224,3 +216,32 @@ fn mult(signatures: &[min_pk::Signature], d: &[u8]) -> min_pk::Signature {
}
ret_affine.into()
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_key_id_from_u64() {
let mut scalar = blst_scalar::default();

let mut arr = [0u8; 128];
StdRng::seed_from_u64(0x1234565EED << 11).fill_bytes(&mut arr);

for i in 0..(arr.len() - 8) {
assert_eq!(
// passing the u64 by value...
&crate::blst::KeyId::try_from(u64::from_le_bytes(
arr[i..i + 8].try_into().unwrap()
))
.unwrap()
.scalar,
// ...should return the same as pointing to our array
unsafe {
blst_scalar_from_le_bytes(&mut scalar, &arr[i], 8);
&scalar
}
);
}
}
}
129 changes: 127 additions & 2 deletions anchor/common/bls_lagrange/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub enum Error {
LessThanTwoSignatures,
NotOneIdPerSignature,
ZeroId,
ZeroKey,
RepeatedId,
InvalidSignature,
}
Expand All @@ -42,9 +43,10 @@ pub(crate) fn random_key(rng: &mut (impl CryptoRng + Rng)) -> Result<SecretKey,

#[cfg(test)]
mod tests {
use std::{hint::black_box, time::Instant};
use std::{hint::black_box, mem, time::Instant};

use bls::Hash256;
use ::blst::{blst_scalar, blst_scalar_from_le_bytes};
use bls::{Hash256, Signature};

use super::*;

Expand Down Expand Up @@ -145,4 +147,127 @@ mod tests {
timing.elapsed().as_millis()
);
}

#[test]
fn test_invalid_threshold() {
let rng = &mut StdRng::seed_from_u64(0x12345EED00000123);
let key = random_key(rng).unwrap();
assert!(matches!(
split_with_rng(
&key,
1,
(1..=10).map(|x| KeyId::try_from(x as u64).unwrap()),
rng
),
Err(Error::InvalidThreshold)
));
assert!(matches!(
split_with_rng(
&key,
0,
(144..=166).map(|x| KeyId::try_from(x as u64).unwrap()),
rng
),
Err(Error::InvalidThreshold)
));
}

#[test]
fn test_less_than_two_sigs() {
let signature = [Signature::empty()];
let key_id = [KeyId::try_from(97).unwrap()];
assert!(matches!(
combine_signatures(&signature, &key_id),
Err(Error::LessThanTwoSignatures)
));
}

#[test]
fn test_not_one_id_per_signature() {
let signatures = [Signature::empty(), Signature::infinity().unwrap()];
let key_ids = [KeyId::try_from(4).unwrap()];
assert!(matches!(
combine_signatures(&signatures, &key_ids),
Err(Error::NotOneIdPerSignature)
));
let signatures = [Signature::infinity().unwrap(), Signature::empty()];
let key_ids = [
KeyId::try_from(2).unwrap(),
KeyId::try_from(1).unwrap(),
KeyId::try_from(4).unwrap(),
];
assert!(matches!(
combine_signatures(&signatures, &key_ids),
Err(Error::NotOneIdPerSignature)
));
}

#[test]
fn test_zero_id() {
assert!(matches!(KeyId::try_from(0), Err(Error::ZeroId)));
}

#[test]
fn test_zero_key() {
let rng = &mut StdRng::seed_from_u64(0x12345EED55500000);
// it's not easy to get a zero key in the first place...
let key = SecretKey::from_point(unsafe {
let mut scalar = blst_scalar::default();
blst_scalar_from_le_bytes(&mut scalar, &0u8, 1);
&mem::transmute::<blst_scalar, ::blst::min_pk::SecretKey>(scalar)
});
assert!(matches!(
split_with_rng(
&key,
3,
(1..=10).map(|x| KeyId::try_from(x as u64).unwrap()),
rng
),
Err(Error::ZeroKey)
));
}

#[test]
fn test_repeated_id() {
let rng = &mut StdRng::seed_from_u64(0xF2345EED0000000);
let master = random_key(rng).unwrap();
let keys = split_with_rng(
&master,
3,
(11..=15).map(|x| KeyId::try_from(x as u64).unwrap()),
rng,
)
.unwrap();

let (ids, keys): (Vec<_>, Vec<_>) = keys.into_iter().unzip();

let mut data = [0u8; 32];
rng.fill(&mut data);

let signers = [0, 1, 1];

let signatures = signers
.iter()
.map(|signer| keys[*signer].sign(Hash256::from(data)))
.collect::<Vec<_>>();
let ids = signers
.into_iter()
.map(|signer| ids[signer].clone())
.collect::<Vec<_>>();

assert!(matches!(
combine_signatures(&signatures, &ids),
Err(Error::RepeatedId)
));
}

#[test]
fn test_invalid_signature() {
let signature = [Signature::empty(), Signature::empty()];
let key_id = [KeyId::try_from(99).unwrap(), KeyId::try_from(98).unwrap()];
assert!(matches!(
combine_signatures(&signature, &key_id),
Err(Error::InvalidSignature)
));
}
}