From 4db96c23b679fca0f29a2ba3d8fb15c27ae6075f Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Tue, 30 Sep 2025 21:56:39 +0000 Subject: [PATCH 1/4] feat: add typed field hasher interface in MiMC package --- ecc/bls12-377/fr/mimc/mimc.go | 24 ++++++++++++++-- ecc/bls12-377/fr/mimc/mimc_test.go | 27 ++++++++++++++++++ ecc/bls12-381/fr/mimc/mimc.go | 24 ++++++++++++++-- ecc/bls12-381/fr/mimc/mimc_test.go | 27 ++++++++++++++++++ ecc/bls24-315/fr/mimc/mimc.go | 24 ++++++++++++++-- ecc/bls24-315/fr/mimc/mimc_test.go | 27 ++++++++++++++++++ ecc/bls24-317/fr/mimc/mimc.go | 24 ++++++++++++++-- ecc/bls24-317/fr/mimc/mimc_test.go | 27 ++++++++++++++++++ ecc/bn254/fr/mimc/mimc.go | 24 ++++++++++++++-- ecc/bn254/fr/mimc/mimc_test.go | 27 ++++++++++++++++++ ecc/bw6-633/fr/mimc/mimc.go | 24 ++++++++++++++-- ecc/bw6-633/fr/mimc/mimc_test.go | 27 ++++++++++++++++++ ecc/bw6-761/fr/mimc/mimc.go | 24 ++++++++++++++-- ecc/bw6-761/fr/mimc/mimc_test.go | 27 ++++++++++++++++++ ecc/grumpkin/fr/mimc/mimc.go | 24 ++++++++++++++-- ecc/grumpkin/fr/mimc/mimc_test.go | 27 ++++++++++++++++++ .../crypto/hash/mimc/template/mimc.go.tmpl | 24 ++++++++++++++-- .../mimc/template/tests/mimc_test.go.tmpl | 28 ++++++++++++++++++- 18 files changed, 441 insertions(+), 19 deletions(-) diff --git a/ecc/bls12-377/fr/mimc/mimc.go b/ecc/bls12-377/fr/mimc/mimc.go index 4c4aec6c6..60f6660cf 100644 --- a/ecc/bls12-377/fr/mimc/mimc.go +++ b/ecc/bls12-377/fr/mimc/mimc.go @@ -17,6 +17,14 @@ import ( "golang.org/x/crypto/sha3" ) +type FieldHasher interface { + hash.StateStorer + WriteElement(e fr.Element) + SumElement() fr.Element +} + +var _ FieldHasher = NewMiMC() + func init() { hash.RegisterHash(hash.MIMC_BLS12_377, func() stdhash.Hash { return NewMiMC() @@ -54,7 +62,7 @@ func GetConstants() []big.Int { } // NewMiMC returns a MiMC implementation, pure Go reference implementation. -func NewMiMC(opts ...Option) hash.StateStorer { +func NewMiMC(opts ...Option) FieldHasher { d := new(digest) d.Reset() cfg := mimcOptions(opts...) @@ -72,12 +80,19 @@ func (d *digest) Reset() { // It does not change the underlying hash state. func (d *digest) Sum(b []byte) []byte { buffer := d.checksum() - d.data = nil // flush the data already hashed + d.data = d.data[:0] hash := buffer.Bytes() b = append(b, hash[:]...) return b } +// SumElement returns the current hash as a field element. +func (d *digest) SumElement() fr.Element { + r := d.checksum() + d.data = d.data[:0] + return r +} + // BlockSize returns the hash's underlying block size. // The Write method must be able to accept any amount // of data, but it may operate more efficiently if all writes @@ -124,6 +139,11 @@ func (d *digest) Write(p []byte) (int, error) { return len(p), nil } +// WriteElement adds a field element to the running hash. 
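+// It is equivalent to calling Write with e.Marshal(), but avoids the byte round-trip.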
+func (d *digest) WriteElement(e fr.Element) { + d.data = append(d.data, e) +} + // Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form diff --git a/ecc/bls12-377/fr/mimc/mimc_test.go b/ecc/bls12-377/fr/mimc/mimc_test.go index 6f076df56..6928ae78f 100644 --- a/ecc/bls12-377/fr/mimc/mimc_test.go +++ b/ecc/bls12-377/fr/mimc/mimc_test.go @@ -116,3 +116,30 @@ func TestSetState(t *testing.T) { } } } + +func TestFieldHasher(t *testing.T) { + assert := require.New(t) + + h1 := mimc.NewMiMC() + h2 := mimc.NewMiMC() + h3 := mimc.NewMiMC() + randInputs := make(fr.Vector, 10) + randInputs.MustSetRandom() + + for i := range randInputs { + h1.Write(randInputs[i].Marshal()) + h2.WriteElement(randInputs[i]) + } + dgst1 := h1.Sum(nil) + dgst2 := h2.Sum(nil) + assert.Equal(dgst1, dgst2, "hashes do not match") + + // test SumElement + h3.WriteElement(randInputs[0]) + for i := 1; i < len(randInputs); i++ { + h3.Write(randInputs[i].Marshal()) + } + _dgst3 := h3.SumElement() + dgst3 := _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") +} diff --git a/ecc/bls12-381/fr/mimc/mimc.go b/ecc/bls12-381/fr/mimc/mimc.go index bedaf2d9e..71d8b8abb 100644 --- a/ecc/bls12-381/fr/mimc/mimc.go +++ b/ecc/bls12-381/fr/mimc/mimc.go @@ -17,6 +17,14 @@ import ( "golang.org/x/crypto/sha3" ) +type FieldHasher interface { + hash.StateStorer + WriteElement(e fr.Element) + SumElement() fr.Element +} + +var _ FieldHasher = NewMiMC() + func init() { hash.RegisterHash(hash.MIMC_BLS12_381, func() stdhash.Hash { return NewMiMC() @@ -54,7 +62,7 @@ func GetConstants() []big.Int { } // NewMiMC returns a MiMC implementation, pure Go reference implementation. -func NewMiMC(opts ...Option) hash.StateStorer { +func NewMiMC(opts ...Option) FieldHasher { d := new(digest) d.Reset() cfg := mimcOptions(opts...) @@ -72,12 +80,19 @@ func (d *digest) Reset() { // It does not change the underlying hash state. func (d *digest) Sum(b []byte) []byte { buffer := d.checksum() - d.data = nil // flush the data already hashed + d.data = d.data[:0] hash := buffer.Bytes() b = append(b, hash[:]...) return b } +// SumElement returns the current hash as a field element. +func (d *digest) SumElement() fr.Element { + r := d.checksum() + d.data = d.data[:0] + return r +} + // BlockSize returns the hash's underlying block size. // The Write method must be able to accept any amount // of data, but it may operate more efficiently if all writes @@ -124,6 +139,11 @@ func (d *digest) Write(p []byte) (int, error) { return len(p), nil } +// WriteElement adds a field element to the running hash. 
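+// It is equivalent to calling Write with e.Marshal(), but avoids the byte round-trip.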
+func (d *digest) WriteElement(e fr.Element) { + d.data = append(d.data, e) +} + // Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form diff --git a/ecc/bls12-381/fr/mimc/mimc_test.go b/ecc/bls12-381/fr/mimc/mimc_test.go index 6f497c36e..98ed618ec 100644 --- a/ecc/bls12-381/fr/mimc/mimc_test.go +++ b/ecc/bls12-381/fr/mimc/mimc_test.go @@ -116,3 +116,30 @@ func TestSetState(t *testing.T) { } } } + +func TestFieldHasher(t *testing.T) { + assert := require.New(t) + + h1 := mimc.NewMiMC() + h2 := mimc.NewMiMC() + h3 := mimc.NewMiMC() + randInputs := make(fr.Vector, 10) + randInputs.MustSetRandom() + + for i := range randInputs { + h1.Write(randInputs[i].Marshal()) + h2.WriteElement(randInputs[i]) + } + dgst1 := h1.Sum(nil) + dgst2 := h2.Sum(nil) + assert.Equal(dgst1, dgst2, "hashes do not match") + + // test SumElement + h3.WriteElement(randInputs[0]) + for i := 1; i < len(randInputs); i++ { + h3.Write(randInputs[i].Marshal()) + } + _dgst3 := h3.SumElement() + dgst3 := _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") +} diff --git a/ecc/bls24-315/fr/mimc/mimc.go b/ecc/bls24-315/fr/mimc/mimc.go index 35b38a584..833e90cb8 100644 --- a/ecc/bls24-315/fr/mimc/mimc.go +++ b/ecc/bls24-315/fr/mimc/mimc.go @@ -17,6 +17,14 @@ import ( "golang.org/x/crypto/sha3" ) +type FieldHasher interface { + hash.StateStorer + WriteElement(e fr.Element) + SumElement() fr.Element +} + +var _ FieldHasher = NewMiMC() + func init() { hash.RegisterHash(hash.MIMC_BLS24_315, func() stdhash.Hash { return NewMiMC() @@ -54,7 +62,7 @@ func GetConstants() []big.Int { } // NewMiMC returns a MiMC implementation, pure Go reference implementation. -func NewMiMC(opts ...Option) hash.StateStorer { +func NewMiMC(opts ...Option) FieldHasher { d := new(digest) d.Reset() cfg := mimcOptions(opts...) @@ -72,12 +80,19 @@ func (d *digest) Reset() { // It does not change the underlying hash state. func (d *digest) Sum(b []byte) []byte { buffer := d.checksum() - d.data = nil // flush the data already hashed + d.data = d.data[:0] hash := buffer.Bytes() b = append(b, hash[:]...) return b } +// SumElement returns the current hash as a field element. +func (d *digest) SumElement() fr.Element { + r := d.checksum() + d.data = d.data[:0] + return r +} + // BlockSize returns the hash's underlying block size. // The Write method must be able to accept any amount // of data, but it may operate more efficiently if all writes @@ -124,6 +139,11 @@ func (d *digest) Write(p []byte) (int, error) { return len(p), nil } +// WriteElement adds a field element to the running hash. 
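+// It is equivalent to calling Write with e.Marshal(), but avoids the byte round-trip.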
+func (d *digest) WriteElement(e fr.Element) { + d.data = append(d.data, e) +} + // Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form diff --git a/ecc/bls24-315/fr/mimc/mimc_test.go b/ecc/bls24-315/fr/mimc/mimc_test.go index 2f901b07e..95a8d9673 100644 --- a/ecc/bls24-315/fr/mimc/mimc_test.go +++ b/ecc/bls24-315/fr/mimc/mimc_test.go @@ -116,3 +116,30 @@ func TestSetState(t *testing.T) { } } } + +func TestFieldHasher(t *testing.T) { + assert := require.New(t) + + h1 := mimc.NewMiMC() + h2 := mimc.NewMiMC() + h3 := mimc.NewMiMC() + randInputs := make(fr.Vector, 10) + randInputs.MustSetRandom() + + for i := range randInputs { + h1.Write(randInputs[i].Marshal()) + h2.WriteElement(randInputs[i]) + } + dgst1 := h1.Sum(nil) + dgst2 := h2.Sum(nil) + assert.Equal(dgst1, dgst2, "hashes do not match") + + // test SumElement + h3.WriteElement(randInputs[0]) + for i := 1; i < len(randInputs); i++ { + h3.Write(randInputs[i].Marshal()) + } + _dgst3 := h3.SumElement() + dgst3 := _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") +} diff --git a/ecc/bls24-317/fr/mimc/mimc.go b/ecc/bls24-317/fr/mimc/mimc.go index 1807b6726..ef166a228 100644 --- a/ecc/bls24-317/fr/mimc/mimc.go +++ b/ecc/bls24-317/fr/mimc/mimc.go @@ -17,6 +17,14 @@ import ( "golang.org/x/crypto/sha3" ) +type FieldHasher interface { + hash.StateStorer + WriteElement(e fr.Element) + SumElement() fr.Element +} + +var _ FieldHasher = NewMiMC() + func init() { hash.RegisterHash(hash.MIMC_BLS24_317, func() stdhash.Hash { return NewMiMC() @@ -54,7 +62,7 @@ func GetConstants() []big.Int { } // NewMiMC returns a MiMC implementation, pure Go reference implementation. -func NewMiMC(opts ...Option) hash.StateStorer { +func NewMiMC(opts ...Option) FieldHasher { d := new(digest) d.Reset() cfg := mimcOptions(opts...) @@ -72,12 +80,19 @@ func (d *digest) Reset() { // It does not change the underlying hash state. func (d *digest) Sum(b []byte) []byte { buffer := d.checksum() - d.data = nil // flush the data already hashed + d.data = d.data[:0] hash := buffer.Bytes() b = append(b, hash[:]...) return b } +// SumElement returns the current hash as a field element. +func (d *digest) SumElement() fr.Element { + r := d.checksum() + d.data = d.data[:0] + return r +} + // BlockSize returns the hash's underlying block size. // The Write method must be able to accept any amount // of data, but it may operate more efficiently if all writes @@ -124,6 +139,11 @@ func (d *digest) Write(p []byte) (int, error) { return len(p), nil } +// WriteElement adds a field element to the running hash. 
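+// It is equivalent to calling Write with e.Marshal(), but avoids the byte round-trip.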
+func (d *digest) WriteElement(e fr.Element) { + d.data = append(d.data, e) +} + // Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form diff --git a/ecc/bls24-317/fr/mimc/mimc_test.go b/ecc/bls24-317/fr/mimc/mimc_test.go index cc838f60c..0eba16c08 100644 --- a/ecc/bls24-317/fr/mimc/mimc_test.go +++ b/ecc/bls24-317/fr/mimc/mimc_test.go @@ -116,3 +116,30 @@ func TestSetState(t *testing.T) { } } } + +func TestFieldHasher(t *testing.T) { + assert := require.New(t) + + h1 := mimc.NewMiMC() + h2 := mimc.NewMiMC() + h3 := mimc.NewMiMC() + randInputs := make(fr.Vector, 10) + randInputs.MustSetRandom() + + for i := range randInputs { + h1.Write(randInputs[i].Marshal()) + h2.WriteElement(randInputs[i]) + } + dgst1 := h1.Sum(nil) + dgst2 := h2.Sum(nil) + assert.Equal(dgst1, dgst2, "hashes do not match") + + // test SumElement + h3.WriteElement(randInputs[0]) + for i := 1; i < len(randInputs); i++ { + h3.Write(randInputs[i].Marshal()) + } + _dgst3 := h3.SumElement() + dgst3 := _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") +} diff --git a/ecc/bn254/fr/mimc/mimc.go b/ecc/bn254/fr/mimc/mimc.go index 91f260e3c..8a20b0bbb 100644 --- a/ecc/bn254/fr/mimc/mimc.go +++ b/ecc/bn254/fr/mimc/mimc.go @@ -17,6 +17,14 @@ import ( "golang.org/x/crypto/sha3" ) +type FieldHasher interface { + hash.StateStorer + WriteElement(e fr.Element) + SumElement() fr.Element +} + +var _ FieldHasher = NewMiMC() + func init() { hash.RegisterHash(hash.MIMC_BN254, func() stdhash.Hash { return NewMiMC() @@ -54,7 +62,7 @@ func GetConstants() []big.Int { } // NewMiMC returns a MiMC implementation, pure Go reference implementation. -func NewMiMC(opts ...Option) hash.StateStorer { +func NewMiMC(opts ...Option) FieldHasher { d := new(digest) d.Reset() cfg := mimcOptions(opts...) @@ -72,12 +80,19 @@ func (d *digest) Reset() { // It does not change the underlying hash state. func (d *digest) Sum(b []byte) []byte { buffer := d.checksum() - d.data = nil // flush the data already hashed + d.data = d.data[:0] hash := buffer.Bytes() b = append(b, hash[:]...) return b } +// SumElement returns the current hash as a field element. +func (d *digest) SumElement() fr.Element { + r := d.checksum() + d.data = d.data[:0] + return r +} + // BlockSize returns the hash's underlying block size. // The Write method must be able to accept any amount // of data, but it may operate more efficiently if all writes @@ -124,6 +139,11 @@ func (d *digest) Write(p []byte) (int, error) { return len(p), nil } +// WriteElement adds a field element to the running hash. 
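+// It is equivalent to calling Write with e.Marshal(), but avoids the byte round-trip.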
+func (d *digest) WriteElement(e fr.Element) { + d.data = append(d.data, e) +} + // Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form diff --git a/ecc/bn254/fr/mimc/mimc_test.go b/ecc/bn254/fr/mimc/mimc_test.go index aa9e21c10..3bc739b45 100644 --- a/ecc/bn254/fr/mimc/mimc_test.go +++ b/ecc/bn254/fr/mimc/mimc_test.go @@ -116,3 +116,30 @@ func TestSetState(t *testing.T) { } } } + +func TestFieldHasher(t *testing.T) { + assert := require.New(t) + + h1 := mimc.NewMiMC() + h2 := mimc.NewMiMC() + h3 := mimc.NewMiMC() + randInputs := make(fr.Vector, 10) + randInputs.MustSetRandom() + + for i := range randInputs { + h1.Write(randInputs[i].Marshal()) + h2.WriteElement(randInputs[i]) + } + dgst1 := h1.Sum(nil) + dgst2 := h2.Sum(nil) + assert.Equal(dgst1, dgst2, "hashes do not match") + + // test SumElement + h3.WriteElement(randInputs[0]) + for i := 1; i < len(randInputs); i++ { + h3.Write(randInputs[i].Marshal()) + } + _dgst3 := h3.SumElement() + dgst3 := _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") +} diff --git a/ecc/bw6-633/fr/mimc/mimc.go b/ecc/bw6-633/fr/mimc/mimc.go index 1245152b8..988c1b8c9 100644 --- a/ecc/bw6-633/fr/mimc/mimc.go +++ b/ecc/bw6-633/fr/mimc/mimc.go @@ -17,6 +17,14 @@ import ( "golang.org/x/crypto/sha3" ) +type FieldHasher interface { + hash.StateStorer + WriteElement(e fr.Element) + SumElement() fr.Element +} + +var _ FieldHasher = NewMiMC() + func init() { hash.RegisterHash(hash.MIMC_BW6_633, func() stdhash.Hash { return NewMiMC() @@ -54,7 +62,7 @@ func GetConstants() []big.Int { } // NewMiMC returns a MiMC implementation, pure Go reference implementation. -func NewMiMC(opts ...Option) hash.StateStorer { +func NewMiMC(opts ...Option) FieldHasher { d := new(digest) d.Reset() cfg := mimcOptions(opts...) @@ -72,12 +80,19 @@ func (d *digest) Reset() { // It does not change the underlying hash state. func (d *digest) Sum(b []byte) []byte { buffer := d.checksum() - d.data = nil // flush the data already hashed + d.data = d.data[:0] hash := buffer.Bytes() b = append(b, hash[:]...) return b } +// SumElement returns the current hash as a field element. +func (d *digest) SumElement() fr.Element { + r := d.checksum() + d.data = d.data[:0] + return r +} + // BlockSize returns the hash's underlying block size. // The Write method must be able to accept any amount // of data, but it may operate more efficiently if all writes @@ -124,6 +139,11 @@ func (d *digest) Write(p []byte) (int, error) { return len(p), nil } +// WriteElement adds a field element to the running hash. 
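+// It is equivalent to calling Write with e.Marshal(), but avoids the byte round-trip.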
+func (d *digest) WriteElement(e fr.Element) { + d.data = append(d.data, e) +} + // Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form diff --git a/ecc/bw6-633/fr/mimc/mimc_test.go b/ecc/bw6-633/fr/mimc/mimc_test.go index 105bf45c2..19b5185c2 100644 --- a/ecc/bw6-633/fr/mimc/mimc_test.go +++ b/ecc/bw6-633/fr/mimc/mimc_test.go @@ -116,3 +116,30 @@ func TestSetState(t *testing.T) { } } } + +func TestFieldHasher(t *testing.T) { + assert := require.New(t) + + h1 := mimc.NewMiMC() + h2 := mimc.NewMiMC() + h3 := mimc.NewMiMC() + randInputs := make(fr.Vector, 10) + randInputs.MustSetRandom() + + for i := range randInputs { + h1.Write(randInputs[i].Marshal()) + h2.WriteElement(randInputs[i]) + } + dgst1 := h1.Sum(nil) + dgst2 := h2.Sum(nil) + assert.Equal(dgst1, dgst2, "hashes do not match") + + // test SumElement + h3.WriteElement(randInputs[0]) + for i := 1; i < len(randInputs); i++ { + h3.Write(randInputs[i].Marshal()) + } + _dgst3 := h3.SumElement() + dgst3 := _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") +} diff --git a/ecc/bw6-761/fr/mimc/mimc.go b/ecc/bw6-761/fr/mimc/mimc.go index d34199504..1bd0d3a98 100644 --- a/ecc/bw6-761/fr/mimc/mimc.go +++ b/ecc/bw6-761/fr/mimc/mimc.go @@ -17,6 +17,14 @@ import ( "golang.org/x/crypto/sha3" ) +type FieldHasher interface { + hash.StateStorer + WriteElement(e fr.Element) + SumElement() fr.Element +} + +var _ FieldHasher = NewMiMC() + func init() { hash.RegisterHash(hash.MIMC_BW6_761, func() stdhash.Hash { return NewMiMC() @@ -54,7 +62,7 @@ func GetConstants() []big.Int { } // NewMiMC returns a MiMC implementation, pure Go reference implementation. -func NewMiMC(opts ...Option) hash.StateStorer { +func NewMiMC(opts ...Option) FieldHasher { d := new(digest) d.Reset() cfg := mimcOptions(opts...) @@ -72,12 +80,19 @@ func (d *digest) Reset() { // It does not change the underlying hash state. func (d *digest) Sum(b []byte) []byte { buffer := d.checksum() - d.data = nil // flush the data already hashed + d.data = d.data[:0] hash := buffer.Bytes() b = append(b, hash[:]...) return b } +// SumElement returns the current hash as a field element. +func (d *digest) SumElement() fr.Element { + r := d.checksum() + d.data = d.data[:0] + return r +} + // BlockSize returns the hash's underlying block size. // The Write method must be able to accept any amount // of data, but it may operate more efficiently if all writes @@ -124,6 +139,11 @@ func (d *digest) Write(p []byte) (int, error) { return len(p), nil } +// WriteElement adds a field element to the running hash. 
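+// It is equivalent to calling Write with e.Marshal(), but avoids the byte round-trip.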
+func (d *digest) WriteElement(e fr.Element) { + d.data = append(d.data, e) +} + // Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form diff --git a/ecc/bw6-761/fr/mimc/mimc_test.go b/ecc/bw6-761/fr/mimc/mimc_test.go index afb5d5e6f..e0e7b21c8 100644 --- a/ecc/bw6-761/fr/mimc/mimc_test.go +++ b/ecc/bw6-761/fr/mimc/mimc_test.go @@ -116,3 +116,30 @@ func TestSetState(t *testing.T) { } } } + +func TestFieldHasher(t *testing.T) { + assert := require.New(t) + + h1 := mimc.NewMiMC() + h2 := mimc.NewMiMC() + h3 := mimc.NewMiMC() + randInputs := make(fr.Vector, 10) + randInputs.MustSetRandom() + + for i := range randInputs { + h1.Write(randInputs[i].Marshal()) + h2.WriteElement(randInputs[i]) + } + dgst1 := h1.Sum(nil) + dgst2 := h2.Sum(nil) + assert.Equal(dgst1, dgst2, "hashes do not match") + + // test SumElement + h3.WriteElement(randInputs[0]) + for i := 1; i < len(randInputs); i++ { + h3.Write(randInputs[i].Marshal()) + } + _dgst3 := h3.SumElement() + dgst3 := _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") +} diff --git a/ecc/grumpkin/fr/mimc/mimc.go b/ecc/grumpkin/fr/mimc/mimc.go index a1fe21bba..05da818b3 100644 --- a/ecc/grumpkin/fr/mimc/mimc.go +++ b/ecc/grumpkin/fr/mimc/mimc.go @@ -17,6 +17,14 @@ import ( "golang.org/x/crypto/sha3" ) +type FieldHasher interface { + hash.StateStorer + WriteElement(e fr.Element) + SumElement() fr.Element +} + +var _ FieldHasher = NewMiMC() + func init() { hash.RegisterHash(hash.MIMC_GRUMPKIN, func() stdhash.Hash { return NewMiMC() @@ -54,7 +62,7 @@ func GetConstants() []big.Int { } // NewMiMC returns a MiMC implementation, pure Go reference implementation. -func NewMiMC(opts ...Option) hash.StateStorer { +func NewMiMC(opts ...Option) FieldHasher { d := new(digest) d.Reset() cfg := mimcOptions(opts...) @@ -72,12 +80,19 @@ func (d *digest) Reset() { // It does not change the underlying hash state. func (d *digest) Sum(b []byte) []byte { buffer := d.checksum() - d.data = nil // flush the data already hashed + d.data = d.data[:0] hash := buffer.Bytes() b = append(b, hash[:]...) return b } +// SumElement returns the current hash as a field element. +func (d *digest) SumElement() fr.Element { + r := d.checksum() + d.data = d.data[:0] + return r +} + // BlockSize returns the hash's underlying block size. // The Write method must be able to accept any amount // of data, but it may operate more efficiently if all writes @@ -124,6 +139,11 @@ func (d *digest) Write(p []byte) (int, error) { return len(p), nil } +// WriteElement adds a field element to the running hash. 
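+// It is equivalent to calling Write with e.Marshal(), but avoids the byte round-trip.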
+func (d *digest) WriteElement(e fr.Element) { + d.data = append(d.data, e) +} + // Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form diff --git a/ecc/grumpkin/fr/mimc/mimc_test.go b/ecc/grumpkin/fr/mimc/mimc_test.go index 88f802a09..bc12e8c0f 100644 --- a/ecc/grumpkin/fr/mimc/mimc_test.go +++ b/ecc/grumpkin/fr/mimc/mimc_test.go @@ -116,3 +116,30 @@ func TestSetState(t *testing.T) { } } } + +func TestFieldHasher(t *testing.T) { + assert := require.New(t) + + h1 := mimc.NewMiMC() + h2 := mimc.NewMiMC() + h3 := mimc.NewMiMC() + randInputs := make(fr.Vector, 10) + randInputs.MustSetRandom() + + for i := range randInputs { + h1.Write(randInputs[i].Marshal()) + h2.WriteElement(randInputs[i]) + } + dgst1 := h1.Sum(nil) + dgst2 := h2.Sum(nil) + assert.Equal(dgst1, dgst2, "hashes do not match") + + // test SumElement + h3.WriteElement(randInputs[0]) + for i := 1; i < len(randInputs); i++ { + h3.Write(randInputs[i].Marshal()) + } + _dgst3 := h3.SumElement() + dgst3 := _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") +} diff --git a/internal/generator/crypto/hash/mimc/template/mimc.go.tmpl b/internal/generator/crypto/hash/mimc/template/mimc.go.tmpl index b1130e209..c96c50ad6 100644 --- a/internal/generator/crypto/hash/mimc/template/mimc.go.tmpl +++ b/internal/generator/crypto/hash/mimc/template/mimc.go.tmpl @@ -10,6 +10,13 @@ import ( "golang.org/x/crypto/sha3" ) +type FieldHasher interface { + hash.StateStorer + WriteElement(e fr.Element) + SumElement() fr.Element +} +var _ FieldHasher = NewMiMC() + func init() { hash.RegisterHash(hash.MIMC_{{ .EnumID }}, func() stdhash.Hash { return NewMiMC() @@ -63,7 +70,7 @@ func GetConstants() []big.Int { } // NewMiMC returns a MiMC implementation, pure Go reference implementation. -func NewMiMC(opts ...Option) hash.StateStorer { +func NewMiMC(opts ...Option) FieldHasher { d := new(digest) d.Reset() cfg := mimcOptions(opts...) @@ -81,12 +88,20 @@ func (d *digest) Reset() { // It does not change the underlying hash state. func (d *digest) Sum(b []byte) []byte { buffer := d.checksum() - d.data = nil // flush the data already hashed + d.data = d.data[:0] hash := buffer.Bytes() b = append(b, hash[:]...) return b } +// SumElement returns the current hash as a field element. +func (d *digest) SumElement() fr.Element { + r := d.checksum() + d.data = d.data[:0] + return r +} + + // BlockSize returns the hash's underlying block size. // The Write method must be able to accept any amount // of data, but it may operate more efficiently if all writes @@ -133,6 +148,11 @@ func (d *digest) Write(p []byte) (int, error) { return len(p), nil } +// WriteElement adds a field element to the running hash. 
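+// It is equivalent to calling Write with e.Marshal(), but avoids the byte round-trip.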
+func (d *digest) WriteElement(e fr.Element) { + d.data = append(d.data, e) +} + // Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form diff --git a/internal/generator/crypto/hash/mimc/template/tests/mimc_test.go.tmpl b/internal/generator/crypto/hash/mimc/template/tests/mimc_test.go.tmpl index 23c448437..bde054711 100644 --- a/internal/generator/crypto/hash/mimc/template/tests/mimc_test.go.tmpl +++ b/internal/generator/crypto/hash/mimc/template/tests/mimc_test.go.tmpl @@ -84,7 +84,6 @@ func TestSetState(t *testing.T) { storedStates := make([][]byte, len(randInputs)) - for i := range randInputs { storedStates[i] = h1.State() @@ -110,3 +109,30 @@ func TestSetState(t *testing.T) { } } } + +func TestFieldHasher(t *testing.T) { + assert := require.New(t) + + h1 := mimc.NewMiMC() + h2 := mimc.NewMiMC() + h3 := mimc.NewMiMC() + randInputs := make(fr.Vector, 10) + randInputs.MustSetRandom() + + for i := range randInputs { + h1.Write(randInputs[i].Marshal()) + h2.WriteElement(randInputs[i]) + } + dgst1 := h1.Sum(nil) + dgst2 := h2.Sum(nil) + assert.Equal(dgst1, dgst2, "hashes do not match") + + // test SumElement + h3.WriteElement(randInputs[0]) + for i := 1; i < len(randInputs); i++ { + h3.Write(randInputs[i].Marshal()) + } + _dgst3 := h3.SumElement() + dgst3 := _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") +} \ No newline at end of file From 045ed2b14272d6546aabd5fc4da15f2c379d8124 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Wed, 1 Oct 2025 02:56:27 +0000 Subject: [PATCH 2/4] feat: update mimc with field hasher --- ecc/bls12-377/fr/mimc/mimc.go | 58 ++++++++++++++++++- ecc/bls12-377/fr/mimc/mimc_test.go | 13 ++++- ecc/bls12-381/fr/mimc/mimc.go | 58 ++++++++++++++++++- ecc/bls12-381/fr/mimc/mimc_test.go | 13 ++++- ecc/bls24-315/fr/mimc/mimc.go | 58 ++++++++++++++++++- ecc/bls24-315/fr/mimc/mimc_test.go | 13 ++++- ecc/bls24-317/fr/mimc/mimc.go | 58 ++++++++++++++++++- ecc/bls24-317/fr/mimc/mimc_test.go | 13 ++++- ecc/bn254/fr/mimc/mimc.go | 58 ++++++++++++++++++- ecc/bn254/fr/mimc/mimc_test.go | 13 ++++- ecc/bw6-633/fr/mimc/mimc.go | 58 ++++++++++++++++++- ecc/bw6-633/fr/mimc/mimc_test.go | 13 ++++- ecc/bw6-761/fr/mimc/mimc.go | 58 ++++++++++++++++++- ecc/bw6-761/fr/mimc/mimc_test.go | 13 ++++- ecc/grumpkin/fr/mimc/mimc.go | 58 ++++++++++++++++++- ecc/grumpkin/fr/mimc/mimc_test.go | 13 ++++- .../crypto/hash/mimc/template/mimc.go.tmpl | 55 +++++++++++++++++- .../mimc/template/tests/mimc_test.go.tmpl | 13 ++++- 18 files changed, 592 insertions(+), 44 deletions(-) diff --git a/ecc/bls12-377/fr/mimc/mimc.go b/ecc/bls12-377/fr/mimc/mimc.go index 60f6660cf..7b7297773 100644 --- a/ecc/bls12-377/fr/mimc/mimc.go +++ b/ecc/bls12-377/fr/mimc/mimc.go @@ -17,13 +17,30 @@ import ( "golang.org/x/crypto/sha3" ) +// FieldHasher is an interface for a hash function that operates on field elements type FieldHasher interface { hash.StateStorer + + // WriteElement adds a field element to the running hash. WriteElement(e fr.Element) + + // SumElement returns the current hash as a field element. SumElement() fr.Element -} -var _ FieldHasher = NewMiMC() + // SumElements returns the current hash as a field element, + // after hashing all the provided elements (in addition to the already hashed ones). + // This is a convenience method to avoid multiple calls to WriteElement + // followed by a call to SumElement. 
+ // It is equivalent to: + // for _, e := range elems { + // h.WriteElement(e) + // } + // return h. SumElement() + // + // This avoids copying the elements into the data slice and + // is more efficient. + SumElements([]fr.Element) fr.Element +} func init() { hash.RegisterHash(hash.MIMC_BLS12_377, func() stdhash.Hash { @@ -61,6 +78,17 @@ func GetConstants() []big.Int { return res } +// NewFieldHasher returns a FieldHasher (works with typed field elements, not bytes) +func NewFieldHasher(opts ...Option) FieldHasher { + r := NewMiMC(opts...) + return r.(FieldHasher) +} + +// NewBinaryHasher returns a hash.StateStorer (works with bytes, not typed field elements) +func NewBinaryHasher(opts ...Option) hash.StateStorer { + return NewMiMC(opts...) +} + // NewMiMC returns a MiMC implementation, pure Go reference implementation. func NewMiMC(opts ...Option) FieldHasher { d := new(digest) @@ -164,6 +192,32 @@ func (d *digest) checksum() fr.Element { return d.h } +// SumElements returns the current hash as a field element, +// after hashing all the provided elements (in addition to the already hashed ones). +// This is a convenience method to avoid multiple calls to WriteElement +// followed by a call to SumElement. +// It is equivalent to: +// +// for _, e := range elems { +// h.WriteElement(e) +// } +// return h. SumElement() +// +// This avoids copying the elements into the data slice and +// is more efficient. +func (d *digest) SumElements(elems []fr.Element) fr.Element { + for i := range d.data { + r := d.encrypt(d.data[i]) + d.h.Add(&r, &d.h).Add(&d.h, &d.data[i]) + } + for i := range elems { + r := d.encrypt(elems[i]) + d.h.Add(&r, &d.h).Add(&d.h, &elems[i]) + } + d.data = d.data[:0] + return d.h +} + // plain execution of a mimc run // m: message // k: encryption key diff --git a/ecc/bls12-377/fr/mimc/mimc_test.go b/ecc/bls12-377/fr/mimc/mimc_test.go index 6928ae78f..d5972e58b 100644 --- a/ecc/bls12-377/fr/mimc/mimc_test.go +++ b/ecc/bls12-377/fr/mimc/mimc_test.go @@ -120,9 +120,9 @@ func TestSetState(t *testing.T) { func TestFieldHasher(t *testing.T) { assert := require.New(t) - h1 := mimc.NewMiMC() - h2 := mimc.NewMiMC() - h3 := mimc.NewMiMC() + h1 := mimc.NewFieldHasher() + h2 := mimc.NewFieldHasher() + h3 := mimc.NewFieldHasher() randInputs := make(fr.Vector, 10) randInputs.MustSetRandom() @@ -142,4 +142,11 @@ func TestFieldHasher(t *testing.T) { _dgst3 := h3.SumElement() dgst3 := _dgst3.Bytes() assert.Equal(dgst1, dgst3[:], "hashes do not match") + + // test SumElements + h3.Reset() + _dgst3 = h3.SumElements(randInputs) + dgst3 = _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") + } diff --git a/ecc/bls12-381/fr/mimc/mimc.go b/ecc/bls12-381/fr/mimc/mimc.go index 71d8b8abb..d6912e0fb 100644 --- a/ecc/bls12-381/fr/mimc/mimc.go +++ b/ecc/bls12-381/fr/mimc/mimc.go @@ -17,13 +17,30 @@ import ( "golang.org/x/crypto/sha3" ) +// FieldHasher is an interface for a hash function that operates on field elements type FieldHasher interface { hash.StateStorer + + // WriteElement adds a field element to the running hash. WriteElement(e fr.Element) + + // SumElement returns the current hash as a field element. SumElement() fr.Element -} -var _ FieldHasher = NewMiMC() + // SumElements returns the current hash as a field element, + // after hashing all the provided elements (in addition to the already hashed ones). + // This is a convenience method to avoid multiple calls to WriteElement + // followed by a call to SumElement. 
+ // It is equivalent to: + // for _, e := range elems { + // h.WriteElement(e) + // } + // return h. SumElement() + // + // This avoids copying the elements into the data slice and + // is more efficient. + SumElements([]fr.Element) fr.Element +} func init() { hash.RegisterHash(hash.MIMC_BLS12_381, func() stdhash.Hash { @@ -61,6 +78,17 @@ func GetConstants() []big.Int { return res } +// NewFieldHasher returns a FieldHasher (works with typed field elements, not bytes) +func NewFieldHasher(opts ...Option) FieldHasher { + r := NewMiMC(opts...) + return r.(FieldHasher) +} + +// NewBinaryHasher returns a hash.StateStorer (works with bytes, not typed field elements) +func NewBinaryHasher(opts ...Option) hash.StateStorer { + return NewMiMC(opts...) +} + // NewMiMC returns a MiMC implementation, pure Go reference implementation. func NewMiMC(opts ...Option) FieldHasher { d := new(digest) @@ -164,6 +192,32 @@ func (d *digest) checksum() fr.Element { return d.h } +// SumElements returns the current hash as a field element, +// after hashing all the provided elements (in addition to the already hashed ones). +// This is a convenience method to avoid multiple calls to WriteElement +// followed by a call to SumElement. +// It is equivalent to: +// +// for _, e := range elems { +// h.WriteElement(e) +// } +// return h. SumElement() +// +// This avoids copying the elements into the data slice and +// is more efficient. +func (d *digest) SumElements(elems []fr.Element) fr.Element { + for i := range d.data { + r := d.encrypt(d.data[i]) + d.h.Add(&r, &d.h).Add(&d.h, &d.data[i]) + } + for i := range elems { + r := d.encrypt(elems[i]) + d.h.Add(&r, &d.h).Add(&d.h, &elems[i]) + } + d.data = d.data[:0] + return d.h +} + // plain execution of a mimc run // m: message // k: encryption key diff --git a/ecc/bls12-381/fr/mimc/mimc_test.go b/ecc/bls12-381/fr/mimc/mimc_test.go index 98ed618ec..808f0cb2d 100644 --- a/ecc/bls12-381/fr/mimc/mimc_test.go +++ b/ecc/bls12-381/fr/mimc/mimc_test.go @@ -120,9 +120,9 @@ func TestSetState(t *testing.T) { func TestFieldHasher(t *testing.T) { assert := require.New(t) - h1 := mimc.NewMiMC() - h2 := mimc.NewMiMC() - h3 := mimc.NewMiMC() + h1 := mimc.NewFieldHasher() + h2 := mimc.NewFieldHasher() + h3 := mimc.NewFieldHasher() randInputs := make(fr.Vector, 10) randInputs.MustSetRandom() @@ -142,4 +142,11 @@ func TestFieldHasher(t *testing.T) { _dgst3 := h3.SumElement() dgst3 := _dgst3.Bytes() assert.Equal(dgst1, dgst3[:], "hashes do not match") + + // test SumElements + h3.Reset() + _dgst3 = h3.SumElements(randInputs) + dgst3 = _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") + } diff --git a/ecc/bls24-315/fr/mimc/mimc.go b/ecc/bls24-315/fr/mimc/mimc.go index 833e90cb8..e8fa5d9a7 100644 --- a/ecc/bls24-315/fr/mimc/mimc.go +++ b/ecc/bls24-315/fr/mimc/mimc.go @@ -17,13 +17,30 @@ import ( "golang.org/x/crypto/sha3" ) +// FieldHasher is an interface for a hash function that operates on field elements type FieldHasher interface { hash.StateStorer + + // WriteElement adds a field element to the running hash. WriteElement(e fr.Element) + + // SumElement returns the current hash as a field element. SumElement() fr.Element -} -var _ FieldHasher = NewMiMC() + // SumElements returns the current hash as a field element, + // after hashing all the provided elements (in addition to the already hashed ones). + // This is a convenience method to avoid multiple calls to WriteElement + // followed by a call to SumElement. 
+ // It is equivalent to: + // for _, e := range elems { + // h.WriteElement(e) + // } + // return h. SumElement() + // + // This avoids copying the elements into the data slice and + // is more efficient. + SumElements([]fr.Element) fr.Element +} func init() { hash.RegisterHash(hash.MIMC_BLS24_315, func() stdhash.Hash { @@ -61,6 +78,17 @@ func GetConstants() []big.Int { return res } +// NewFieldHasher returns a FieldHasher (works with typed field elements, not bytes) +func NewFieldHasher(opts ...Option) FieldHasher { + r := NewMiMC(opts...) + return r.(FieldHasher) +} + +// NewBinaryHasher returns a hash.StateStorer (works with bytes, not typed field elements) +func NewBinaryHasher(opts ...Option) hash.StateStorer { + return NewMiMC(opts...) +} + // NewMiMC returns a MiMC implementation, pure Go reference implementation. func NewMiMC(opts ...Option) FieldHasher { d := new(digest) @@ -164,6 +192,32 @@ func (d *digest) checksum() fr.Element { return d.h } +// SumElements returns the current hash as a field element, +// after hashing all the provided elements (in addition to the already hashed ones). +// This is a convenience method to avoid multiple calls to WriteElement +// followed by a call to SumElement. +// It is equivalent to: +// +// for _, e := range elems { +// h.WriteElement(e) +// } +// return h. SumElement() +// +// This avoids copying the elements into the data slice and +// is more efficient. +func (d *digest) SumElements(elems []fr.Element) fr.Element { + for i := range d.data { + r := d.encrypt(d.data[i]) + d.h.Add(&r, &d.h).Add(&d.h, &d.data[i]) + } + for i := range elems { + r := d.encrypt(elems[i]) + d.h.Add(&r, &d.h).Add(&d.h, &elems[i]) + } + d.data = d.data[:0] + return d.h +} + // plain execution of a mimc run // m: message // k: encryption key diff --git a/ecc/bls24-315/fr/mimc/mimc_test.go b/ecc/bls24-315/fr/mimc/mimc_test.go index 95a8d9673..d6f36855f 100644 --- a/ecc/bls24-315/fr/mimc/mimc_test.go +++ b/ecc/bls24-315/fr/mimc/mimc_test.go @@ -120,9 +120,9 @@ func TestSetState(t *testing.T) { func TestFieldHasher(t *testing.T) { assert := require.New(t) - h1 := mimc.NewMiMC() - h2 := mimc.NewMiMC() - h3 := mimc.NewMiMC() + h1 := mimc.NewFieldHasher() + h2 := mimc.NewFieldHasher() + h3 := mimc.NewFieldHasher() randInputs := make(fr.Vector, 10) randInputs.MustSetRandom() @@ -142,4 +142,11 @@ func TestFieldHasher(t *testing.T) { _dgst3 := h3.SumElement() dgst3 := _dgst3.Bytes() assert.Equal(dgst1, dgst3[:], "hashes do not match") + + // test SumElements + h3.Reset() + _dgst3 = h3.SumElements(randInputs) + dgst3 = _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") + } diff --git a/ecc/bls24-317/fr/mimc/mimc.go b/ecc/bls24-317/fr/mimc/mimc.go index ef166a228..b2329216b 100644 --- a/ecc/bls24-317/fr/mimc/mimc.go +++ b/ecc/bls24-317/fr/mimc/mimc.go @@ -17,13 +17,30 @@ import ( "golang.org/x/crypto/sha3" ) +// FieldHasher is an interface for a hash function that operates on field elements type FieldHasher interface { hash.StateStorer + + // WriteElement adds a field element to the running hash. WriteElement(e fr.Element) + + // SumElement returns the current hash as a field element. SumElement() fr.Element -} -var _ FieldHasher = NewMiMC() + // SumElements returns the current hash as a field element, + // after hashing all the provided elements (in addition to the already hashed ones). + // This is a convenience method to avoid multiple calls to WriteElement + // followed by a call to SumElement. 
+ // It is equivalent to: + // for _, e := range elems { + // h.WriteElement(e) + // } + // return h. SumElement() + // + // This avoids copying the elements into the data slice and + // is more efficient. + SumElements([]fr.Element) fr.Element +} func init() { hash.RegisterHash(hash.MIMC_BLS24_317, func() stdhash.Hash { @@ -61,6 +78,17 @@ func GetConstants() []big.Int { return res } +// NewFieldHasher returns a FieldHasher (works with typed field elements, not bytes) +func NewFieldHasher(opts ...Option) FieldHasher { + r := NewMiMC(opts...) + return r.(FieldHasher) +} + +// NewBinaryHasher returns a hash.StateStorer (works with bytes, not typed field elements) +func NewBinaryHasher(opts ...Option) hash.StateStorer { + return NewMiMC(opts...) +} + // NewMiMC returns a MiMC implementation, pure Go reference implementation. func NewMiMC(opts ...Option) FieldHasher { d := new(digest) @@ -164,6 +192,32 @@ func (d *digest) checksum() fr.Element { return d.h } +// SumElements returns the current hash as a field element, +// after hashing all the provided elements (in addition to the already hashed ones). +// This is a convenience method to avoid multiple calls to WriteElement +// followed by a call to SumElement. +// It is equivalent to: +// +// for _, e := range elems { +// h.WriteElement(e) +// } +// return h. SumElement() +// +// This avoids copying the elements into the data slice and +// is more efficient. +func (d *digest) SumElements(elems []fr.Element) fr.Element { + for i := range d.data { + r := d.encrypt(d.data[i]) + d.h.Add(&r, &d.h).Add(&d.h, &d.data[i]) + } + for i := range elems { + r := d.encrypt(elems[i]) + d.h.Add(&r, &d.h).Add(&d.h, &elems[i]) + } + d.data = d.data[:0] + return d.h +} + // plain execution of a mimc run // m: message // k: encryption key diff --git a/ecc/bls24-317/fr/mimc/mimc_test.go b/ecc/bls24-317/fr/mimc/mimc_test.go index 0eba16c08..ed41c62bc 100644 --- a/ecc/bls24-317/fr/mimc/mimc_test.go +++ b/ecc/bls24-317/fr/mimc/mimc_test.go @@ -120,9 +120,9 @@ func TestSetState(t *testing.T) { func TestFieldHasher(t *testing.T) { assert := require.New(t) - h1 := mimc.NewMiMC() - h2 := mimc.NewMiMC() - h3 := mimc.NewMiMC() + h1 := mimc.NewFieldHasher() + h2 := mimc.NewFieldHasher() + h3 := mimc.NewFieldHasher() randInputs := make(fr.Vector, 10) randInputs.MustSetRandom() @@ -142,4 +142,11 @@ func TestFieldHasher(t *testing.T) { _dgst3 := h3.SumElement() dgst3 := _dgst3.Bytes() assert.Equal(dgst1, dgst3[:], "hashes do not match") + + // test SumElements + h3.Reset() + _dgst3 = h3.SumElements(randInputs) + dgst3 = _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") + } diff --git a/ecc/bn254/fr/mimc/mimc.go b/ecc/bn254/fr/mimc/mimc.go index 8a20b0bbb..73e9f741d 100644 --- a/ecc/bn254/fr/mimc/mimc.go +++ b/ecc/bn254/fr/mimc/mimc.go @@ -17,13 +17,30 @@ import ( "golang.org/x/crypto/sha3" ) +// FieldHasher is an interface for a hash function that operates on field elements type FieldHasher interface { hash.StateStorer + + // WriteElement adds a field element to the running hash. WriteElement(e fr.Element) + + // SumElement returns the current hash as a field element. SumElement() fr.Element -} -var _ FieldHasher = NewMiMC() + // SumElements returns the current hash as a field element, + // after hashing all the provided elements (in addition to the already hashed ones). + // This is a convenience method to avoid multiple calls to WriteElement + // followed by a call to SumElement. 
+ // It is equivalent to: + // for _, e := range elems { + // h.WriteElement(e) + // } + // return h. SumElement() + // + // This avoids copying the elements into the data slice and + // is more efficient. + SumElements([]fr.Element) fr.Element +} func init() { hash.RegisterHash(hash.MIMC_BN254, func() stdhash.Hash { @@ -61,6 +78,17 @@ func GetConstants() []big.Int { return res } +// NewFieldHasher returns a FieldHasher (works with typed field elements, not bytes) +func NewFieldHasher(opts ...Option) FieldHasher { + r := NewMiMC(opts...) + return r.(FieldHasher) +} + +// NewBinaryHasher returns a hash.StateStorer (works with bytes, not typed field elements) +func NewBinaryHasher(opts ...Option) hash.StateStorer { + return NewMiMC(opts...) +} + // NewMiMC returns a MiMC implementation, pure Go reference implementation. func NewMiMC(opts ...Option) FieldHasher { d := new(digest) @@ -164,6 +192,32 @@ func (d *digest) checksum() fr.Element { return d.h } +// SumElements returns the current hash as a field element, +// after hashing all the provided elements (in addition to the already hashed ones). +// This is a convenience method to avoid multiple calls to WriteElement +// followed by a call to SumElement. +// It is equivalent to: +// +// for _, e := range elems { +// h.WriteElement(e) +// } +// return h. SumElement() +// +// This avoids copying the elements into the data slice and +// is more efficient. +func (d *digest) SumElements(elems []fr.Element) fr.Element { + for i := range d.data { + r := d.encrypt(d.data[i]) + d.h.Add(&r, &d.h).Add(&d.h, &d.data[i]) + } + for i := range elems { + r := d.encrypt(elems[i]) + d.h.Add(&r, &d.h).Add(&d.h, &elems[i]) + } + d.data = d.data[:0] + return d.h +} + // plain execution of a mimc run // m: message // k: encryption key diff --git a/ecc/bn254/fr/mimc/mimc_test.go b/ecc/bn254/fr/mimc/mimc_test.go index 3bc739b45..c4737fda4 100644 --- a/ecc/bn254/fr/mimc/mimc_test.go +++ b/ecc/bn254/fr/mimc/mimc_test.go @@ -120,9 +120,9 @@ func TestSetState(t *testing.T) { func TestFieldHasher(t *testing.T) { assert := require.New(t) - h1 := mimc.NewMiMC() - h2 := mimc.NewMiMC() - h3 := mimc.NewMiMC() + h1 := mimc.NewFieldHasher() + h2 := mimc.NewFieldHasher() + h3 := mimc.NewFieldHasher() randInputs := make(fr.Vector, 10) randInputs.MustSetRandom() @@ -142,4 +142,11 @@ func TestFieldHasher(t *testing.T) { _dgst3 := h3.SumElement() dgst3 := _dgst3.Bytes() assert.Equal(dgst1, dgst3[:], "hashes do not match") + + // test SumElements + h3.Reset() + _dgst3 = h3.SumElements(randInputs) + dgst3 = _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") + } diff --git a/ecc/bw6-633/fr/mimc/mimc.go b/ecc/bw6-633/fr/mimc/mimc.go index 988c1b8c9..7f3cbe465 100644 --- a/ecc/bw6-633/fr/mimc/mimc.go +++ b/ecc/bw6-633/fr/mimc/mimc.go @@ -17,13 +17,30 @@ import ( "golang.org/x/crypto/sha3" ) +// FieldHasher is an interface for a hash function that operates on field elements type FieldHasher interface { hash.StateStorer + + // WriteElement adds a field element to the running hash. WriteElement(e fr.Element) + + // SumElement returns the current hash as a field element. SumElement() fr.Element -} -var _ FieldHasher = NewMiMC() + // SumElements returns the current hash as a field element, + // after hashing all the provided elements (in addition to the already hashed ones). + // This is a convenience method to avoid multiple calls to WriteElement + // followed by a call to SumElement. 
+ // It is equivalent to: + // for _, e := range elems { + // h.WriteElement(e) + // } + // return h. SumElement() + // + // This avoids copying the elements into the data slice and + // is more efficient. + SumElements([]fr.Element) fr.Element +} func init() { hash.RegisterHash(hash.MIMC_BW6_633, func() stdhash.Hash { @@ -61,6 +78,17 @@ func GetConstants() []big.Int { return res } +// NewFieldHasher returns a FieldHasher (works with typed field elements, not bytes) +func NewFieldHasher(opts ...Option) FieldHasher { + r := NewMiMC(opts...) + return r.(FieldHasher) +} + +// NewBinaryHasher returns a hash.StateStorer (works with bytes, not typed field elements) +func NewBinaryHasher(opts ...Option) hash.StateStorer { + return NewMiMC(opts...) +} + // NewMiMC returns a MiMC implementation, pure Go reference implementation. func NewMiMC(opts ...Option) FieldHasher { d := new(digest) @@ -164,6 +192,32 @@ func (d *digest) checksum() fr.Element { return d.h } +// SumElements returns the current hash as a field element, +// after hashing all the provided elements (in addition to the already hashed ones). +// This is a convenience method to avoid multiple calls to WriteElement +// followed by a call to SumElement. +// It is equivalent to: +// +// for _, e := range elems { +// h.WriteElement(e) +// } +// return h. SumElement() +// +// This avoids copying the elements into the data slice and +// is more efficient. +func (d *digest) SumElements(elems []fr.Element) fr.Element { + for i := range d.data { + r := d.encrypt(d.data[i]) + d.h.Add(&r, &d.h).Add(&d.h, &d.data[i]) + } + for i := range elems { + r := d.encrypt(elems[i]) + d.h.Add(&r, &d.h).Add(&d.h, &elems[i]) + } + d.data = d.data[:0] + return d.h +} + // plain execution of a mimc run // m: message // k: encryption key diff --git a/ecc/bw6-633/fr/mimc/mimc_test.go b/ecc/bw6-633/fr/mimc/mimc_test.go index 19b5185c2..13694fe58 100644 --- a/ecc/bw6-633/fr/mimc/mimc_test.go +++ b/ecc/bw6-633/fr/mimc/mimc_test.go @@ -120,9 +120,9 @@ func TestSetState(t *testing.T) { func TestFieldHasher(t *testing.T) { assert := require.New(t) - h1 := mimc.NewMiMC() - h2 := mimc.NewMiMC() - h3 := mimc.NewMiMC() + h1 := mimc.NewFieldHasher() + h2 := mimc.NewFieldHasher() + h3 := mimc.NewFieldHasher() randInputs := make(fr.Vector, 10) randInputs.MustSetRandom() @@ -142,4 +142,11 @@ func TestFieldHasher(t *testing.T) { _dgst3 := h3.SumElement() dgst3 := _dgst3.Bytes() assert.Equal(dgst1, dgst3[:], "hashes do not match") + + // test SumElements + h3.Reset() + _dgst3 = h3.SumElements(randInputs) + dgst3 = _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") + } diff --git a/ecc/bw6-761/fr/mimc/mimc.go b/ecc/bw6-761/fr/mimc/mimc.go index 1bd0d3a98..33770aec7 100644 --- a/ecc/bw6-761/fr/mimc/mimc.go +++ b/ecc/bw6-761/fr/mimc/mimc.go @@ -17,13 +17,30 @@ import ( "golang.org/x/crypto/sha3" ) +// FieldHasher is an interface for a hash function that operates on field elements type FieldHasher interface { hash.StateStorer + + // WriteElement adds a field element to the running hash. WriteElement(e fr.Element) + + // SumElement returns the current hash as a field element. SumElement() fr.Element -} -var _ FieldHasher = NewMiMC() + // SumElements returns the current hash as a field element, + // after hashing all the provided elements (in addition to the already hashed ones). + // This is a convenience method to avoid multiple calls to WriteElement + // followed by a call to SumElement. 
+ // It is equivalent to: + // for _, e := range elems { + // h.WriteElement(e) + // } + // return h. SumElement() + // + // This avoids copying the elements into the data slice and + // is more efficient. + SumElements([]fr.Element) fr.Element +} func init() { hash.RegisterHash(hash.MIMC_BW6_761, func() stdhash.Hash { @@ -61,6 +78,17 @@ func GetConstants() []big.Int { return res } +// NewFieldHasher returns a FieldHasher (works with typed field elements, not bytes) +func NewFieldHasher(opts ...Option) FieldHasher { + r := NewMiMC(opts...) + return r.(FieldHasher) +} + +// NewBinaryHasher returns a hash.StateStorer (works with bytes, not typed field elements) +func NewBinaryHasher(opts ...Option) hash.StateStorer { + return NewMiMC(opts...) +} + // NewMiMC returns a MiMC implementation, pure Go reference implementation. func NewMiMC(opts ...Option) FieldHasher { d := new(digest) @@ -164,6 +192,32 @@ func (d *digest) checksum() fr.Element { return d.h } +// SumElements returns the current hash as a field element, +// after hashing all the provided elements (in addition to the already hashed ones). +// This is a convenience method to avoid multiple calls to WriteElement +// followed by a call to SumElement. +// It is equivalent to: +// +// for _, e := range elems { +// h.WriteElement(e) +// } +// return h. SumElement() +// +// This avoids copying the elements into the data slice and +// is more efficient. +func (d *digest) SumElements(elems []fr.Element) fr.Element { + for i := range d.data { + r := d.encrypt(d.data[i]) + d.h.Add(&r, &d.h).Add(&d.h, &d.data[i]) + } + for i := range elems { + r := d.encrypt(elems[i]) + d.h.Add(&r, &d.h).Add(&d.h, &elems[i]) + } + d.data = d.data[:0] + return d.h +} + // plain execution of a mimc run // m: message // k: encryption key diff --git a/ecc/bw6-761/fr/mimc/mimc_test.go b/ecc/bw6-761/fr/mimc/mimc_test.go index e0e7b21c8..2d44034b0 100644 --- a/ecc/bw6-761/fr/mimc/mimc_test.go +++ b/ecc/bw6-761/fr/mimc/mimc_test.go @@ -120,9 +120,9 @@ func TestSetState(t *testing.T) { func TestFieldHasher(t *testing.T) { assert := require.New(t) - h1 := mimc.NewMiMC() - h2 := mimc.NewMiMC() - h3 := mimc.NewMiMC() + h1 := mimc.NewFieldHasher() + h2 := mimc.NewFieldHasher() + h3 := mimc.NewFieldHasher() randInputs := make(fr.Vector, 10) randInputs.MustSetRandom() @@ -142,4 +142,11 @@ func TestFieldHasher(t *testing.T) { _dgst3 := h3.SumElement() dgst3 := _dgst3.Bytes() assert.Equal(dgst1, dgst3[:], "hashes do not match") + + // test SumElements + h3.Reset() + _dgst3 = h3.SumElements(randInputs) + dgst3 = _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") + } diff --git a/ecc/grumpkin/fr/mimc/mimc.go b/ecc/grumpkin/fr/mimc/mimc.go index 05da818b3..a47790cab 100644 --- a/ecc/grumpkin/fr/mimc/mimc.go +++ b/ecc/grumpkin/fr/mimc/mimc.go @@ -17,13 +17,30 @@ import ( "golang.org/x/crypto/sha3" ) +// FieldHasher is an interface for a hash function that operates on field elements type FieldHasher interface { hash.StateStorer + + // WriteElement adds a field element to the running hash. WriteElement(e fr.Element) + + // SumElement returns the current hash as a field element. SumElement() fr.Element -} -var _ FieldHasher = NewMiMC() + // SumElements returns the current hash as a field element, + // after hashing all the provided elements (in addition to the already hashed ones). + // This is a convenience method to avoid multiple calls to WriteElement + // followed by a call to SumElement. 
+ // It is equivalent to: + // for _, e := range elems { + // h.WriteElement(e) + // } + // return h. SumElement() + // + // This avoids copying the elements into the data slice and + // is more efficient. + SumElements([]fr.Element) fr.Element +} func init() { hash.RegisterHash(hash.MIMC_GRUMPKIN, func() stdhash.Hash { @@ -61,6 +78,17 @@ func GetConstants() []big.Int { return res } +// NewFieldHasher returns a FieldHasher (works with typed field elements, not bytes) +func NewFieldHasher(opts ...Option) FieldHasher { + r := NewMiMC(opts...) + return r.(FieldHasher) +} + +// NewBinaryHasher returns a hash.StateStorer (works with bytes, not typed field elements) +func NewBinaryHasher(opts ...Option) hash.StateStorer { + return NewMiMC(opts...) +} + // NewMiMC returns a MiMC implementation, pure Go reference implementation. func NewMiMC(opts ...Option) FieldHasher { d := new(digest) @@ -164,6 +192,32 @@ func (d *digest) checksum() fr.Element { return d.h } +// SumElements returns the current hash as a field element, +// after hashing all the provided elements (in addition to the already hashed ones). +// This is a convenience method to avoid multiple calls to WriteElement +// followed by a call to SumElement. +// It is equivalent to: +// +// for _, e := range elems { +// h.WriteElement(e) +// } +// return h. SumElement() +// +// This avoids copying the elements into the data slice and +// is more efficient. +func (d *digest) SumElements(elems []fr.Element) fr.Element { + for i := range d.data { + r := d.encrypt(d.data[i]) + d.h.Add(&r, &d.h).Add(&d.h, &d.data[i]) + } + for i := range elems { + r := d.encrypt(elems[i]) + d.h.Add(&r, &d.h).Add(&d.h, &elems[i]) + } + d.data = d.data[:0] + return d.h +} + // plain execution of a mimc run // m: message // k: encryption key diff --git a/ecc/grumpkin/fr/mimc/mimc_test.go b/ecc/grumpkin/fr/mimc/mimc_test.go index bc12e8c0f..b26236c1b 100644 --- a/ecc/grumpkin/fr/mimc/mimc_test.go +++ b/ecc/grumpkin/fr/mimc/mimc_test.go @@ -120,9 +120,9 @@ func TestSetState(t *testing.T) { func TestFieldHasher(t *testing.T) { assert := require.New(t) - h1 := mimc.NewMiMC() - h2 := mimc.NewMiMC() - h3 := mimc.NewMiMC() + h1 := mimc.NewFieldHasher() + h2 := mimc.NewFieldHasher() + h3 := mimc.NewFieldHasher() randInputs := make(fr.Vector, 10) randInputs.MustSetRandom() @@ -142,4 +142,11 @@ func TestFieldHasher(t *testing.T) { _dgst3 := h3.SumElement() dgst3 := _dgst3.Bytes() assert.Equal(dgst1, dgst3[:], "hashes do not match") + + // test SumElements + h3.Reset() + _dgst3 = h3.SumElements(randInputs) + dgst3 = _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") + } diff --git a/internal/generator/crypto/hash/mimc/template/mimc.go.tmpl b/internal/generator/crypto/hash/mimc/template/mimc.go.tmpl index c96c50ad6..8cdd8f2ef 100644 --- a/internal/generator/crypto/hash/mimc/template/mimc.go.tmpl +++ b/internal/generator/crypto/hash/mimc/template/mimc.go.tmpl @@ -10,12 +10,30 @@ import ( "golang.org/x/crypto/sha3" ) +// FieldHasher is an interface for a hash function that operates on field elements type FieldHasher interface { hash.StateStorer + + // WriteElement adds a field element to the running hash. WriteElement(e fr.Element) + + // SumElement returns the current hash as a field element. SumElement() fr.Element + + // SumElements returns the current hash as a field element, + // after hashing all the provided elements (in addition to the already hashed ones). 
+ // This is a convenience method to avoid multiple calls to WriteElement + // followed by a call to SumElement. + // It is equivalent to: + // for _, e := range elems { + // h.WriteElement(e) + // } + // return h. SumElement() + // + // This avoids copying the elements into the data slice and + // is more efficient. + SumElements([]fr.Element) fr.Element } -var _ FieldHasher = NewMiMC() func init() { hash.RegisterHash(hash.MIMC_{{ .EnumID }}, func() stdhash.Hash { @@ -69,6 +87,17 @@ func GetConstants() []big.Int { return res } +// NewFieldHasher returns a FieldHasher (works with typed field elements, not bytes) +func NewFieldHasher(opts ...Option) FieldHasher { + r := NewMiMC(opts...) + return r.(FieldHasher) +} + +// NewBinaryHasher returns a hash.StateStorer (works with bytes, not typed field elements) +func NewBinaryHasher(opts ...Option) hash.StateStorer { + return NewMiMC(opts...) +} + // NewMiMC returns a MiMC implementation, pure Go reference implementation. func NewMiMC(opts ...Option) FieldHasher { d := new(digest) @@ -173,6 +202,30 @@ func (d *digest) checksum() fr.Element { return d.h } +// SumElements returns the current hash as a field element, +// after hashing all the provided elements (in addition to the already hashed ones). +// This is a convenience method to avoid multiple calls to WriteElement +// followed by a call to SumElement. +// It is equivalent to: +// for _, e := range elems { +// h.WriteElement(e) +// } +// return h. SumElement() +// +// This avoids copying the elements into the data slice and +// is more efficient. +func (d *digest) SumElements(elems []fr.Element) fr.Element { + for i := range d.data { + r := d.encrypt(d.data[i]) + d.h.Add(&r, &d.h).Add(&d.h, &d.data[i]) + } + for i := range elems { + r := d.encrypt(elems[i]) + d.h.Add(&r, &d.h).Add(&d.h, &elems[i]) + } + d.data = d.data[:0] + return d.h +} {{ if eq .Name "bls12-377" }} // plain execution of a mimc run diff --git a/internal/generator/crypto/hash/mimc/template/tests/mimc_test.go.tmpl b/internal/generator/crypto/hash/mimc/template/tests/mimc_test.go.tmpl index bde054711..4a70661ed 100644 --- a/internal/generator/crypto/hash/mimc/template/tests/mimc_test.go.tmpl +++ b/internal/generator/crypto/hash/mimc/template/tests/mimc_test.go.tmpl @@ -113,9 +113,9 @@ func TestSetState(t *testing.T) { func TestFieldHasher(t *testing.T) { assert := require.New(t) - h1 := mimc.NewMiMC() - h2 := mimc.NewMiMC() - h3 := mimc.NewMiMC() + h1 := mimc.NewFieldHasher() + h2 := mimc.NewFieldHasher() + h3 := mimc.NewFieldHasher() randInputs := make(fr.Vector, 10) randInputs.MustSetRandom() @@ -135,4 +135,11 @@ func TestFieldHasher(t *testing.T) { _dgst3 := h3.SumElement() dgst3 := _dgst3.Bytes() assert.Equal(dgst1, dgst3[:], "hashes do not match") + + // test SumElements + h3.Reset() + _dgst3 = h3.SumElements(randInputs) + dgst3 = _dgst3.Bytes() + assert.Equal(dgst1, dgst3[:], "hashes do not match") + } \ No newline at end of file From 69a8557bc7782851726aa1b2ee6173775cc4f6be Mon Sep 17 00:00:00 2001 From: "Yao J. 
Galteland" <73404195+YaoJGalteland@users.noreply.github.com> Date: Tue, 9 Sep 2025 19:13:57 +0200 Subject: [PATCH 3/4] Feat/newdomain optimization (#737) Co-authored-by: Gautam Botrel Co-authored-by: Ivo Kubjas --- ecc/bls12-377/fr/fft/domain.go | 111 ++++++++++++++++- ecc/bls12-377/fr/fft/domain_test.go | 98 +++++++++++++++ ecc/bls12-377/fr/fft/options.go | 9 ++ ecc/bls12-381/fr/fft/domain.go | 111 ++++++++++++++++- ecc/bls12-381/fr/fft/domain_test.go | 98 +++++++++++++++ ecc/bls12-381/fr/fft/options.go | 9 ++ ecc/bls24-315/fr/fft/domain.go | 111 ++++++++++++++++- ecc/bls24-315/fr/fft/domain_test.go | 98 +++++++++++++++ ecc/bls24-315/fr/fft/options.go | 9 ++ ecc/bls24-317/fr/fft/domain.go | 111 ++++++++++++++++- ecc/bls24-317/fr/fft/domain_test.go | 98 +++++++++++++++ ecc/bls24-317/fr/fft/options.go | 9 ++ ecc/bn254/fr/fft/domain.go | 111 ++++++++++++++++- ecc/bn254/fr/fft/domain_test.go | 98 +++++++++++++++ ecc/bn254/fr/fft/options.go | 9 ++ ecc/bw6-633/fr/fft/domain.go | 111 ++++++++++++++++- ecc/bw6-633/fr/fft/domain_test.go | 98 +++++++++++++++ ecc/bw6-633/fr/fft/options.go | 9 ++ ecc/bw6-761/fr/fft/domain.go | 111 ++++++++++++++++- ecc/bw6-761/fr/fft/domain_test.go | 98 +++++++++++++++ ecc/bw6-761/fr/fft/options.go | 9 ++ field/babybear/extensions/e2_test.go | 33 +++++ field/babybear/fft/domain.go | 111 ++++++++++++++++- field/babybear/fft/domain_test.go | 98 +++++++++++++++ field/babybear/fft/options.go | 9 ++ .../templates/extensions/e2_test.go.tmpl | 40 ++++++ .../internal/templates/fft/domain.go.tmpl | 114 +++++++++++++++++- .../internal/templates/fft/options.go.tmpl | 9 ++ .../templates/fft/tests/domain.go.tmpl | 104 +++++++++++++++- field/goldilocks/extensions/e2_test.go | 33 +++++ field/goldilocks/fft/domain.go | 111 ++++++++++++++++- field/goldilocks/fft/domain_test.go | 98 +++++++++++++++ field/goldilocks/fft/options.go | 9 ++ field/koalabear/extensions/e2_test.go | 33 +++++ field/koalabear/fft/domain.go | 111 ++++++++++++++++- field/koalabear/fft/domain_test.go | 98 +++++++++++++++ field/koalabear/fft/options.go | 9 ++ 37 files changed, 2510 insertions(+), 36 deletions(-) diff --git a/ecc/bls12-377/fr/fft/domain.go b/ecc/bls12-377/fr/fft/domain.go index 1d44a171a..3e6c9ad8c 100644 --- a/ecc/bls12-377/fr/fft/domain.go +++ b/ecc/bls12-377/fr/fft/domain.go @@ -13,6 +13,7 @@ import ( "math/bits" "runtime" "sync" + "weak" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" @@ -61,11 +62,115 @@ func GeneratorFullMultiplicativeGroup() fr.Element { return res } -// NewDomain returns a subgroup with a power of 2 cardinality -// cardinality >= m -// shift: when specified, it's the element by which the set of root of unity is shifted. +// domainCacheKey is the composite key for the cache. +// It uses a struct with comparable types as the map key. +type domainCacheKey struct { + m uint64 + gen fr.Element +} + +var ( + domainCache = make(map[domainCacheKey]weak.Pointer[Domain]) + domainGenLocks = make(map[domainCacheKey]*sync.Mutex) // Per key mutex to avoid multiple concurrent generation of the same domain + keyMapLock sync.Mutex // Ensures exclusive access to domainGenLocks map + domainMapLock sync.Mutex // Ensures exclusive access to domainCache map +) + +// NewDomain returns a subgroup with a power of 2 cardinality >= m. +// +// Parameters: +// - m: minimum cardinality (will be rounded up to next power of 2) +// - opts: configuration options (WithShift, WithCache, WithoutPrecompute, etc.) +// +// The domain can be cached when both withCache and withPrecompute are enabled. 
+// Cached domains are automatically cleaned up when no longer in use. func NewDomain(m uint64, opts ...DomainOption) *Domain { opt := domainOptions(opts...) + + // Skip caching if disabled or precomputation is off + if !opt.withCache || !opt.withPrecompute { + return createDomain(m, opt) + } + + // Compute the cache key. + key := domainCacheKey{m: m} + if opt.shift != nil { + key.gen.Set(opt.shift) + } else { + key.gen = GeneratorFullMultiplicativeGroup() // Default generator + } + + // Lets ensure that only one goroutine is generating a domain for this + // specific key. We acquire it already here to ensure if there is a existing + // goroutine generating a domain for this key, we wait for it to finish and + // then we can just return the cached domain. + keyMapLock.Lock() + keyLock := domainGenLocks[key] + if keyLock == nil { + keyLock = new(sync.Mutex) + domainGenLocks[key] = keyLock + } + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + // Check cache first. But for the cache, we need to lock the cache map (we + // currently only hold the per-key lock, not global cache lock). + domainMapLock.Lock() + // we don't defer it because we want to release it while creating the + // domain. And domain creation can panic, leading to double unlock which + // hides the original panic. + if weakDomain, exists := domainCache[key]; exists { + if domain := weakDomain.Value(); domain != nil { + domainMapLock.Unlock() + return domain + } + } + // Lets release the global cache lock while we do this so that other keys + // can be added to cache. + domainMapLock.Unlock() + + // Create a new domain (expensive operation, but only blocks same key). + domain := createDomain(m, opt) + + // Store in cache with cleanup + weakDomain := weak.Make(domain) + domainMapLock.Lock() + domainCache[key] = weakDomain + domainMapLock.Unlock() + + // Add cleanup to remove from cache when domain is garbage collected + runtime.AddCleanup(domain, func(key domainCacheKey) { + // cleanup *may* be called concurrently, but could be sequential. We run + // it in a separate goroutine to avoid block other cleanups being run if + // this cleanup is being run on the same key which is being generated + // (thus lock being held). + go func() { + keyMapLock.Lock() + defer keyMapLock.Unlock() + if keyLock, ok := domainGenLocks[key]; ok { + keyLock.Lock() + defer keyLock.Unlock() + // We can now safely delete from both maps. But we only do if + // the cached weak pointer is the same one we created. Otherwise + // this means this cleanup is running after a new domain was + // already cached (double cleanup). 
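A minimal sketch (not part of the diff) of the caching behaviour this adds, assuming the bls12-377 fft package; with precompute on (the default), WithCache makes repeated calls for the same size and shift return the same *Domain, while a call without the option still builds a fresh one:

package main

import (
	"fmt"

	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft"
)

func main() {
	// first call builds the domain and stores a weak reference to it in the cache
	d1 := fft.NewDomain(1<<16, fft.WithCache())

	// same (cardinality, shift) key: the cached pointer is returned, no recomputation
	d2 := fft.NewDomain(1<<16, fft.WithCache())
	fmt.Println(d1 == d2) // true

	// without WithCache (or with WithoutPrecompute) a new domain is always created
	d3 := fft.NewDomain(1 << 16)
	fmt.Println(d1 == d3) // false
}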
+ + // We also want to hold both per-key and cache lock to avoid the + // maps being out of sync + domainMapLock.Lock() + defer domainMapLock.Unlock() + if cacheWeakDomain := domainCache[key]; cacheWeakDomain == weakDomain { + delete(domainCache, key) + delete(domainGenLocks, key) + } + } + }() + }, key) + return domain +} + +func createDomain(m uint64, opt domainConfig) *Domain { domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) diff --git a/ecc/bls12-377/fr/fft/domain_test.go b/ecc/bls12-377/fr/fft/domain_test.go index 7049120e6..8b539cd50 100644 --- a/ecc/bls12-377/fr/fft/domain_test.go +++ b/ecc/bls12-377/fr/fft/domain_test.go @@ -8,7 +8,11 @@ package fft import ( "bytes" "reflect" + "runtime" "testing" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/stretchr/testify/require" ) func TestDomainSerialization(t *testing.T) { @@ -34,3 +38,97 @@ func TestDomainSerialization(t *testing.T) { t.Fatal("Domain.SetBytes(Bytes()) failed") } } + +func TestNewDomainCache(t *testing.T) { + t.Run("CacheWithoutShift", func(t *testing.T) { + key1 := domainCacheKey{ + m: 256, + gen: GeneratorFullMultiplicativeGroup(), + } + require.Nil(t, getCachedDomain(key1), "cache should be empty initially") + domain1 := NewDomain(256, WithCache()) + expected1 := NewDomain(256) + require.Equal(t, domain1, expected1, "domain1 should equal expected1") + require.Same(t, domain1, getCachedDomain(key1), "domain1 should be stored in cache") + }) + + t.Run("CacheWithShift", func(t *testing.T) { + shift := fr.NewElement(5) + key2 := domainCacheKey{ + m: 512, + gen: shift, + } + require.Nil(t, getCachedDomain(key2), "cache should be empty initially") + domain2 := NewDomain(512, WithShift(shift), WithCache()) + expected2 := NewDomain(512, WithShift(shift)) + require.Equal(t, domain2, expected2, "domain2 should equal expected2") + require.Same(t, domain2, getCachedDomain(key2), "domain2 should be stored in cache") + }) +} + +func TestGCBehavior(t *testing.T) { + t.Run("DomainKeptAliveNotCollected", func(t *testing.T) { + domain := NewDomain(1<<20, WithCache()) + require.NotNil(t, domain, "domain should not be empty") + key := domainCacheKey{ + m: 1 << 20, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + runtime.GC() + require.NotNil(t, getCachedDomain(key), "domain should still be cached") + runtime.KeepAlive(domain) + }) + + t.Run("UnreferencedDomainCollected", func(t *testing.T) { + domain := NewDomain(1<<19, WithCache()) + key := domainCacheKey{ + m: 1 << 19, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + // Last use of domain + _ = domain.Cardinality + runtime.GC() + require.Nil(t, getCachedDomain(key), "unreferenced domain should be collected and removed from cache") + }) +} + +func BenchmarkNewDomainCache(b *testing.B) { + b.Run("WithCache", func(b *testing.B) { + // lets first initialize in cache already + cached := NewDomain(1<<20, WithCache()) + for b.Loop() { + _ = NewDomain(1<<20, WithCache()) + } + runtime.KeepAlive(cached) // prevent cached from being GCed + }) + + b.Run("WithoutCache", func(b *testing.B) { + for b.Loop() { + _ = NewDomain(1 << 20) + } + }) +} + +// Helper functions +func getCachedDomain(key domainCacheKey) *Domain { + keyMapLock.Lock() + keyLock, exists := domainGenLocks[key] + if !exists { + keyMapLock.Unlock() + return nil + } + + // Acquire key lock while holding global lock to prevent races + 
keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + domainMapLock.Lock() + defer domainMapLock.Unlock() + if weak, exists := domainCache[key]; exists { + return weak.Value() + } + return nil +} diff --git a/ecc/bls12-377/fr/fft/options.go b/ecc/bls12-377/fr/fft/options.go index a562b0ae7..e4ed53672 100644 --- a/ecc/bls12-377/fr/fft/options.go +++ b/ecc/bls12-377/fr/fft/options.go @@ -63,6 +63,7 @@ type DomainOption func(*domainConfig) type domainConfig struct { shift *fr.Element withPrecompute bool + withCache bool } // WithShift sets the FrMultiplicativeGen of the domain. @@ -81,11 +82,19 @@ func WithoutPrecompute() DomainOption { } } +// WithCache enables domain caching +func WithCache() DomainOption { + return func(opt *domainConfig) { + opt.withCache = true + } +} + // default options func domainOptions(opts ...DomainOption) domainConfig { // apply options opt := domainConfig{ withPrecompute: true, + withCache: false, } for _, option := range opts { option(&opt) diff --git a/ecc/bls12-381/fr/fft/domain.go b/ecc/bls12-381/fr/fft/domain.go index 03f1fbf49..65e73d8e0 100644 --- a/ecc/bls12-381/fr/fft/domain.go +++ b/ecc/bls12-381/fr/fft/domain.go @@ -13,6 +13,7 @@ import ( "math/bits" "runtime" "sync" + "weak" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" @@ -61,11 +62,115 @@ func GeneratorFullMultiplicativeGroup() fr.Element { return res } -// NewDomain returns a subgroup with a power of 2 cardinality -// cardinality >= m -// shift: when specified, it's the element by which the set of root of unity is shifted. +// domainCacheKey is the composite key for the cache. +// It uses a struct with comparable types as the map key. +type domainCacheKey struct { + m uint64 + gen fr.Element +} + +var ( + domainCache = make(map[domainCacheKey]weak.Pointer[Domain]) + domainGenLocks = make(map[domainCacheKey]*sync.Mutex) // Per key mutex to avoid multiple concurrent generation of the same domain + keyMapLock sync.Mutex // Ensures exclusive access to domainGenLocks map + domainMapLock sync.Mutex // Ensures exclusive access to domainCache map +) + +// NewDomain returns a subgroup with a power of 2 cardinality >= m. +// +// Parameters: +// - m: minimum cardinality (will be rounded up to next power of 2) +// - opts: configuration options (WithShift, WithCache, WithoutPrecompute, etc.) +// +// The domain can be cached when both withCache and withPrecompute are enabled. +// Cached domains are automatically cleaned up when no longer in use. func NewDomain(m uint64, opts ...DomainOption) *Domain { opt := domainOptions(opts...) + + // Skip caching if disabled or precomputation is off + if !opt.withCache || !opt.withPrecompute { + return createDomain(m, opt) + } + + // Compute the cache key. + key := domainCacheKey{m: m} + if opt.shift != nil { + key.gen.Set(opt.shift) + } else { + key.gen = GeneratorFullMultiplicativeGroup() // Default generator + } + + // Lets ensure that only one goroutine is generating a domain for this + // specific key. We acquire it already here to ensure if there is a existing + // goroutine generating a domain for this key, we wait for it to finish and + // then we can just return the cached domain. + keyMapLock.Lock() + keyLock := domainGenLocks[key] + if keyLock == nil { + keyLock = new(sync.Mutex) + domainGenLocks[key] = keyLock + } + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + // Check cache first. But for the cache, we need to lock the cache map (we + // currently only hold the per-key lock, not global cache lock). 
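The per-key mutex taken above gives singleflight-style behaviour: concurrent callers asking for the same key wait for the first builder instead of each recomputing the domain, while callers for other keys proceed in parallel. A stripped-down sketch of just that locking discipline, using a plain strong-reference map instead of the weak cache and hypothetical names:

package main

import (
	"fmt"
	"sync"
)

var (
	keyLocksMu sync.Mutex
	keyLocks   = make(map[string]*sync.Mutex)
	cacheMu    sync.Mutex
	cache      = make(map[string]int)
)

// expensive stand-in for createDomain
func build(key string) int { return len(key) }

func getOrBuild(key string) int {
	// take (or create) the per-key lock under the global map lock,
	// then release the global lock so other keys are not serialized
	keyLocksMu.Lock()
	l := keyLocks[key]
	if l == nil {
		l = new(sync.Mutex)
		keyLocks[key] = l
	}
	l.Lock()
	defer l.Unlock()
	keyLocksMu.Unlock()

	cacheMu.Lock()
	if v, ok := cache[key]; ok {
		cacheMu.Unlock()
		return v
	}
	cacheMu.Unlock()

	v := build(key) // at most one goroutine per key reaches this point at a time
	cacheMu.Lock()
	cache[key] = v
	cacheMu.Unlock()
	return v
}

func main() {
	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func() { defer wg.Done(); _ = getOrBuild("domain-1024") }()
	}
	wg.Wait()
	fmt.Println(getOrBuild("domain-1024")) // 11
}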
+ domainMapLock.Lock() + // we don't defer it because we want to release it while creating the + // domain. And domain creation can panic, leading to double unlock which + // hides the original panic. + if weakDomain, exists := domainCache[key]; exists { + if domain := weakDomain.Value(); domain != nil { + domainMapLock.Unlock() + return domain + } + } + // Lets release the global cache lock while we do this so that other keys + // can be added to cache. + domainMapLock.Unlock() + + // Create a new domain (expensive operation, but only blocks same key). + domain := createDomain(m, opt) + + // Store in cache with cleanup + weakDomain := weak.Make(domain) + domainMapLock.Lock() + domainCache[key] = weakDomain + domainMapLock.Unlock() + + // Add cleanup to remove from cache when domain is garbage collected + runtime.AddCleanup(domain, func(key domainCacheKey) { + // cleanup *may* be called concurrently, but could be sequential. We run + // it in a separate goroutine to avoid block other cleanups being run if + // this cleanup is being run on the same key which is being generated + // (thus lock being held). + go func() { + keyMapLock.Lock() + defer keyMapLock.Unlock() + if keyLock, ok := domainGenLocks[key]; ok { + keyLock.Lock() + defer keyLock.Unlock() + // We can now safely delete from both maps. But we only do if + // the cached weak pointer is the same one we created. Otherwise + // this means this cleanup is running after a new domain was + // already cached (double cleanup). + + // We also want to hold both per-key and cache lock to avoid the + // maps being out of sync + domainMapLock.Lock() + defer domainMapLock.Unlock() + if cacheWeakDomain := domainCache[key]; cacheWeakDomain == weakDomain { + delete(domainCache, key) + delete(domainGenLocks, key) + } + } + }() + }, key) + return domain +} + +func createDomain(m uint64, opt domainConfig) *Domain { domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) diff --git a/ecc/bls12-381/fr/fft/domain_test.go b/ecc/bls12-381/fr/fft/domain_test.go index 7049120e6..9f21b00e2 100644 --- a/ecc/bls12-381/fr/fft/domain_test.go +++ b/ecc/bls12-381/fr/fft/domain_test.go @@ -8,7 +8,11 @@ package fft import ( "bytes" "reflect" + "runtime" "testing" + + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/stretchr/testify/require" ) func TestDomainSerialization(t *testing.T) { @@ -34,3 +38,97 @@ func TestDomainSerialization(t *testing.T) { t.Fatal("Domain.SetBytes(Bytes()) failed") } } + +func TestNewDomainCache(t *testing.T) { + t.Run("CacheWithoutShift", func(t *testing.T) { + key1 := domainCacheKey{ + m: 256, + gen: GeneratorFullMultiplicativeGroup(), + } + require.Nil(t, getCachedDomain(key1), "cache should be empty initially") + domain1 := NewDomain(256, WithCache()) + expected1 := NewDomain(256) + require.Equal(t, domain1, expected1, "domain1 should equal expected1") + require.Same(t, domain1, getCachedDomain(key1), "domain1 should be stored in cache") + }) + + t.Run("CacheWithShift", func(t *testing.T) { + shift := fr.NewElement(5) + key2 := domainCacheKey{ + m: 512, + gen: shift, + } + require.Nil(t, getCachedDomain(key2), "cache should be empty initially") + domain2 := NewDomain(512, WithShift(shift), WithCache()) + expected2 := NewDomain(512, WithShift(shift)) + require.Equal(t, domain2, expected2, "domain2 should equal expected2") + require.Same(t, domain2, getCachedDomain(key2), "domain2 should be stored in cache") + }) +} + +func TestGCBehavior(t *testing.T) { + t.Run("DomainKeptAliveNotCollected", 
func(t *testing.T) { + domain := NewDomain(1<<20, WithCache()) + require.NotNil(t, domain, "domain should not be empty") + key := domainCacheKey{ + m: 1 << 20, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + runtime.GC() + require.NotNil(t, getCachedDomain(key), "domain should still be cached") + runtime.KeepAlive(domain) + }) + + t.Run("UnreferencedDomainCollected", func(t *testing.T) { + domain := NewDomain(1<<19, WithCache()) + key := domainCacheKey{ + m: 1 << 19, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + // Last use of domain + _ = domain.Cardinality + runtime.GC() + require.Nil(t, getCachedDomain(key), "unreferenced domain should be collected and removed from cache") + }) +} + +func BenchmarkNewDomainCache(b *testing.B) { + b.Run("WithCache", func(b *testing.B) { + // lets first initialize in cache already + cached := NewDomain(1<<20, WithCache()) + for b.Loop() { + _ = NewDomain(1<<20, WithCache()) + } + runtime.KeepAlive(cached) // prevent cached from being GCed + }) + + b.Run("WithoutCache", func(b *testing.B) { + for b.Loop() { + _ = NewDomain(1 << 20) + } + }) +} + +// Helper functions +func getCachedDomain(key domainCacheKey) *Domain { + keyMapLock.Lock() + keyLock, exists := domainGenLocks[key] + if !exists { + keyMapLock.Unlock() + return nil + } + + // Acquire key lock while holding global lock to prevent races + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + domainMapLock.Lock() + defer domainMapLock.Unlock() + if weak, exists := domainCache[key]; exists { + return weak.Value() + } + return nil +} diff --git a/ecc/bls12-381/fr/fft/options.go b/ecc/bls12-381/fr/fft/options.go index e705081cd..a775c3f48 100644 --- a/ecc/bls12-381/fr/fft/options.go +++ b/ecc/bls12-381/fr/fft/options.go @@ -63,6 +63,7 @@ type DomainOption func(*domainConfig) type domainConfig struct { shift *fr.Element withPrecompute bool + withCache bool } // WithShift sets the FrMultiplicativeGen of the domain. @@ -81,11 +82,19 @@ func WithoutPrecompute() DomainOption { } } +// WithCache enables domain caching +func WithCache() DomainOption { + return func(opt *domainConfig) { + opt.withCache = true + } +} + // default options func domainOptions(opts ...DomainOption) domainConfig { // apply options opt := domainConfig{ withPrecompute: true, + withCache: false, } for _, option := range opts { option(&opt) diff --git a/ecc/bls24-315/fr/fft/domain.go b/ecc/bls24-315/fr/fft/domain.go index 1b8860e7d..fd9f1c8d2 100644 --- a/ecc/bls24-315/fr/fft/domain.go +++ b/ecc/bls24-315/fr/fft/domain.go @@ -13,6 +13,7 @@ import ( "math/bits" "runtime" "sync" + "weak" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" @@ -61,11 +62,115 @@ func GeneratorFullMultiplicativeGroup() fr.Element { return res } -// NewDomain returns a subgroup with a power of 2 cardinality -// cardinality >= m -// shift: when specified, it's the element by which the set of root of unity is shifted. +// domainCacheKey is the composite key for the cache. +// It uses a struct with comparable types as the map key. 
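Since the cache key combines the cardinality with the generator, shifted (coset) domains and plain domains of the same size are cached independently; a minimal sketch (not part of the diff), assuming the bls24-315 fft package and a shift distinct from the default generator:

package main

import (
	"fmt"

	"github.com/consensys/gnark-crypto/ecc/bls24-315/fr"
	"github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft"
)

func main() {
	shift := fr.NewElement(5)

	// two different cache keys: (1024, shift) and (1024, default generator)
	dCoset := fft.NewDomain(1<<10, fft.WithShift(shift), fft.WithCache())
	dPlain := fft.NewDomain(1<<10, fft.WithCache())
	fmt.Println(dCoset == dPlain) // false: distinct domains, each cached under its own key

	// asking again for the coset domain hits the cache
	fmt.Println(dCoset == fft.NewDomain(1<<10, fft.WithShift(shift), fft.WithCache())) // true
}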
+type domainCacheKey struct { + m uint64 + gen fr.Element +} + +var ( + domainCache = make(map[domainCacheKey]weak.Pointer[Domain]) + domainGenLocks = make(map[domainCacheKey]*sync.Mutex) // Per key mutex to avoid multiple concurrent generation of the same domain + keyMapLock sync.Mutex // Ensures exclusive access to domainGenLocks map + domainMapLock sync.Mutex // Ensures exclusive access to domainCache map +) + +// NewDomain returns a subgroup with a power of 2 cardinality >= m. +// +// Parameters: +// - m: minimum cardinality (will be rounded up to next power of 2) +// - opts: configuration options (WithShift, WithCache, WithoutPrecompute, etc.) +// +// The domain can be cached when both withCache and withPrecompute are enabled. +// Cached domains are automatically cleaned up when no longer in use. func NewDomain(m uint64, opts ...DomainOption) *Domain { opt := domainOptions(opts...) + + // Skip caching if disabled or precomputation is off + if !opt.withCache || !opt.withPrecompute { + return createDomain(m, opt) + } + + // Compute the cache key. + key := domainCacheKey{m: m} + if opt.shift != nil { + key.gen.Set(opt.shift) + } else { + key.gen = GeneratorFullMultiplicativeGroup() // Default generator + } + + // Lets ensure that only one goroutine is generating a domain for this + // specific key. We acquire it already here to ensure if there is a existing + // goroutine generating a domain for this key, we wait for it to finish and + // then we can just return the cached domain. + keyMapLock.Lock() + keyLock := domainGenLocks[key] + if keyLock == nil { + keyLock = new(sync.Mutex) + domainGenLocks[key] = keyLock + } + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + // Check cache first. But for the cache, we need to lock the cache map (we + // currently only hold the per-key lock, not global cache lock). + domainMapLock.Lock() + // we don't defer it because we want to release it while creating the + // domain. And domain creation can panic, leading to double unlock which + // hides the original panic. + if weakDomain, exists := domainCache[key]; exists { + if domain := weakDomain.Value(); domain != nil { + domainMapLock.Unlock() + return domain + } + } + // Lets release the global cache lock while we do this so that other keys + // can be added to cache. + domainMapLock.Unlock() + + // Create a new domain (expensive operation, but only blocks same key). + domain := createDomain(m, opt) + + // Store in cache with cleanup + weakDomain := weak.Make(domain) + domainMapLock.Lock() + domainCache[key] = weakDomain + domainMapLock.Unlock() + + // Add cleanup to remove from cache when domain is garbage collected + runtime.AddCleanup(domain, func(key domainCacheKey) { + // cleanup *may* be called concurrently, but could be sequential. We run + // it in a separate goroutine to avoid block other cleanups being run if + // this cleanup is being run on the same key which is being generated + // (thus lock being held). + go func() { + keyMapLock.Lock() + defer keyMapLock.Unlock() + if keyLock, ok := domainGenLocks[key]; ok { + keyLock.Lock() + defer keyLock.Unlock() + // We can now safely delete from both maps. But we only do if + // the cached weak pointer is the same one we created. Otherwise + // this means this cleanup is running after a new domain was + // already cached (double cleanup). 
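The lifetime management above relies on the weak package and runtime.AddCleanup (both Go 1.24+): the cache only holds weak.Pointer values, so a cached Domain that no other code references can be collected, and the cleanup then removes the stale map entry. A stripped-down sketch of that pattern for an arbitrary expensive value, with hypothetical names and without the per-key lock used in the patch:

package main

import (
	"fmt"
	"runtime"
	"sync"
	"weak"
)

type entry struct{ data []byte } // stand-in for an expensive-to-build value

var (
	mu    sync.Mutex
	cache = make(map[string]weak.Pointer[entry])
)

func get(key string) *entry {
	mu.Lock()
	if wp, ok := cache[key]; ok {
		if e := wp.Value(); e != nil { // still strongly referenced somewhere
			mu.Unlock()
			return e
		}
	}
	mu.Unlock()

	e := &entry{data: make([]byte, 1<<20)} // expensive build, outside the lock
	wp := weak.Make(e)

	mu.Lock()
	cache[key] = wp
	mu.Unlock()

	// once e is garbage collected, drop the map entry, but only if it still
	// refers to the weak pointer created here (a newer value may have replaced it)
	runtime.AddCleanup(e, func(k string) {
		mu.Lock()
		defer mu.Unlock()
		if cache[k] == wp {
			delete(cache, k)
		}
	}, key)
	return e
}

func main() {
	a := get("k")
	b := get("k")
	fmt.Println(a == b) // true while the first value is still alive
}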
+ + // We also want to hold both per-key and cache lock to avoid the + // maps being out of sync + domainMapLock.Lock() + defer domainMapLock.Unlock() + if cacheWeakDomain := domainCache[key]; cacheWeakDomain == weakDomain { + delete(domainCache, key) + delete(domainGenLocks, key) + } + } + }() + }, key) + return domain +} + +func createDomain(m uint64, opt domainConfig) *Domain { domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) diff --git a/ecc/bls24-315/fr/fft/domain_test.go b/ecc/bls24-315/fr/fft/domain_test.go index 7049120e6..35e1bce78 100644 --- a/ecc/bls24-315/fr/fft/domain_test.go +++ b/ecc/bls24-315/fr/fft/domain_test.go @@ -8,7 +8,11 @@ package fft import ( "bytes" "reflect" + "runtime" "testing" + + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/stretchr/testify/require" ) func TestDomainSerialization(t *testing.T) { @@ -34,3 +38,97 @@ func TestDomainSerialization(t *testing.T) { t.Fatal("Domain.SetBytes(Bytes()) failed") } } + +func TestNewDomainCache(t *testing.T) { + t.Run("CacheWithoutShift", func(t *testing.T) { + key1 := domainCacheKey{ + m: 256, + gen: GeneratorFullMultiplicativeGroup(), + } + require.Nil(t, getCachedDomain(key1), "cache should be empty initially") + domain1 := NewDomain(256, WithCache()) + expected1 := NewDomain(256) + require.Equal(t, domain1, expected1, "domain1 should equal expected1") + require.Same(t, domain1, getCachedDomain(key1), "domain1 should be stored in cache") + }) + + t.Run("CacheWithShift", func(t *testing.T) { + shift := fr.NewElement(5) + key2 := domainCacheKey{ + m: 512, + gen: shift, + } + require.Nil(t, getCachedDomain(key2), "cache should be empty initially") + domain2 := NewDomain(512, WithShift(shift), WithCache()) + expected2 := NewDomain(512, WithShift(shift)) + require.Equal(t, domain2, expected2, "domain2 should equal expected2") + require.Same(t, domain2, getCachedDomain(key2), "domain2 should be stored in cache") + }) +} + +func TestGCBehavior(t *testing.T) { + t.Run("DomainKeptAliveNotCollected", func(t *testing.T) { + domain := NewDomain(1<<20, WithCache()) + require.NotNil(t, domain, "domain should not be empty") + key := domainCacheKey{ + m: 1 << 20, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + runtime.GC() + require.NotNil(t, getCachedDomain(key), "domain should still be cached") + runtime.KeepAlive(domain) + }) + + t.Run("UnreferencedDomainCollected", func(t *testing.T) { + domain := NewDomain(1<<19, WithCache()) + key := domainCacheKey{ + m: 1 << 19, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + // Last use of domain + _ = domain.Cardinality + runtime.GC() + require.Nil(t, getCachedDomain(key), "unreferenced domain should be collected and removed from cache") + }) +} + +func BenchmarkNewDomainCache(b *testing.B) { + b.Run("WithCache", func(b *testing.B) { + // lets first initialize in cache already + cached := NewDomain(1<<20, WithCache()) + for b.Loop() { + _ = NewDomain(1<<20, WithCache()) + } + runtime.KeepAlive(cached) // prevent cached from being GCed + }) + + b.Run("WithoutCache", func(b *testing.B) { + for b.Loop() { + _ = NewDomain(1 << 20) + } + }) +} + +// Helper functions +func getCachedDomain(key domainCacheKey) *Domain { + keyMapLock.Lock() + keyLock, exists := domainGenLocks[key] + if !exists { + keyMapLock.Unlock() + return nil + } + + // Acquire key lock while holding global lock to prevent races + 
keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + domainMapLock.Lock() + defer domainMapLock.Unlock() + if weak, exists := domainCache[key]; exists { + return weak.Value() + } + return nil +} diff --git a/ecc/bls24-315/fr/fft/options.go b/ecc/bls24-315/fr/fft/options.go index 8538f4cda..c621412df 100644 --- a/ecc/bls24-315/fr/fft/options.go +++ b/ecc/bls24-315/fr/fft/options.go @@ -63,6 +63,7 @@ type DomainOption func(*domainConfig) type domainConfig struct { shift *fr.Element withPrecompute bool + withCache bool } // WithShift sets the FrMultiplicativeGen of the domain. @@ -81,11 +82,19 @@ func WithoutPrecompute() DomainOption { } } +// WithCache enables domain caching +func WithCache() DomainOption { + return func(opt *domainConfig) { + opt.withCache = true + } +} + // default options func domainOptions(opts ...DomainOption) domainConfig { // apply options opt := domainConfig{ withPrecompute: true, + withCache: false, } for _, option := range opts { option(&opt) diff --git a/ecc/bls24-317/fr/fft/domain.go b/ecc/bls24-317/fr/fft/domain.go index c3745da08..a2fee9121 100644 --- a/ecc/bls24-317/fr/fft/domain.go +++ b/ecc/bls24-317/fr/fft/domain.go @@ -13,6 +13,7 @@ import ( "math/bits" "runtime" "sync" + "weak" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" @@ -61,11 +62,115 @@ func GeneratorFullMultiplicativeGroup() fr.Element { return res } -// NewDomain returns a subgroup with a power of 2 cardinality -// cardinality >= m -// shift: when specified, it's the element by which the set of root of unity is shifted. +// domainCacheKey is the composite key for the cache. +// It uses a struct with comparable types as the map key. +type domainCacheKey struct { + m uint64 + gen fr.Element +} + +var ( + domainCache = make(map[domainCacheKey]weak.Pointer[Domain]) + domainGenLocks = make(map[domainCacheKey]*sync.Mutex) // Per key mutex to avoid multiple concurrent generation of the same domain + keyMapLock sync.Mutex // Ensures exclusive access to domainGenLocks map + domainMapLock sync.Mutex // Ensures exclusive access to domainCache map +) + +// NewDomain returns a subgroup with a power of 2 cardinality >= m. +// +// Parameters: +// - m: minimum cardinality (will be rounded up to next power of 2) +// - opts: configuration options (WithShift, WithCache, WithoutPrecompute, etc.) +// +// The domain can be cached when both withCache and withPrecompute are enabled. +// Cached domains are automatically cleaned up when no longer in use. func NewDomain(m uint64, opts ...DomainOption) *Domain { opt := domainOptions(opts...) + + // Skip caching if disabled or precomputation is off + if !opt.withCache || !opt.withPrecompute { + return createDomain(m, opt) + } + + // Compute the cache key. + key := domainCacheKey{m: m} + if opt.shift != nil { + key.gen.Set(opt.shift) + } else { + key.gen = GeneratorFullMultiplicativeGroup() // Default generator + } + + // Lets ensure that only one goroutine is generating a domain for this + // specific key. We acquire it already here to ensure if there is a existing + // goroutine generating a domain for this key, we wait for it to finish and + // then we can just return the cached domain. + keyMapLock.Lock() + keyLock := domainGenLocks[key] + if keyLock == nil { + keyLock = new(sync.Mutex) + domainGenLocks[key] = keyLock + } + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + // Check cache first. But for the cache, we need to lock the cache map (we + // currently only hold the per-key lock, not global cache lock). 
+ domainMapLock.Lock() + // we don't defer it because we want to release it while creating the + // domain. And domain creation can panic, leading to double unlock which + // hides the original panic. + if weakDomain, exists := domainCache[key]; exists { + if domain := weakDomain.Value(); domain != nil { + domainMapLock.Unlock() + return domain + } + } + // Lets release the global cache lock while we do this so that other keys + // can be added to cache. + domainMapLock.Unlock() + + // Create a new domain (expensive operation, but only blocks same key). + domain := createDomain(m, opt) + + // Store in cache with cleanup + weakDomain := weak.Make(domain) + domainMapLock.Lock() + domainCache[key] = weakDomain + domainMapLock.Unlock() + + // Add cleanup to remove from cache when domain is garbage collected + runtime.AddCleanup(domain, func(key domainCacheKey) { + // cleanup *may* be called concurrently, but could be sequential. We run + // it in a separate goroutine to avoid block other cleanups being run if + // this cleanup is being run on the same key which is being generated + // (thus lock being held). + go func() { + keyMapLock.Lock() + defer keyMapLock.Unlock() + if keyLock, ok := domainGenLocks[key]; ok { + keyLock.Lock() + defer keyLock.Unlock() + // We can now safely delete from both maps. But we only do if + // the cached weak pointer is the same one we created. Otherwise + // this means this cleanup is running after a new domain was + // already cached (double cleanup). + + // We also want to hold both per-key and cache lock to avoid the + // maps being out of sync + domainMapLock.Lock() + defer domainMapLock.Unlock() + if cacheWeakDomain := domainCache[key]; cacheWeakDomain == weakDomain { + delete(domainCache, key) + delete(domainGenLocks, key) + } + } + }() + }, key) + return domain +} + +func createDomain(m uint64, opt domainConfig) *Domain { domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) diff --git a/ecc/bls24-317/fr/fft/domain_test.go b/ecc/bls24-317/fr/fft/domain_test.go index 7049120e6..346befd5a 100644 --- a/ecc/bls24-317/fr/fft/domain_test.go +++ b/ecc/bls24-317/fr/fft/domain_test.go @@ -8,7 +8,11 @@ package fft import ( "bytes" "reflect" + "runtime" "testing" + + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/stretchr/testify/require" ) func TestDomainSerialization(t *testing.T) { @@ -34,3 +38,97 @@ func TestDomainSerialization(t *testing.T) { t.Fatal("Domain.SetBytes(Bytes()) failed") } } + +func TestNewDomainCache(t *testing.T) { + t.Run("CacheWithoutShift", func(t *testing.T) { + key1 := domainCacheKey{ + m: 256, + gen: GeneratorFullMultiplicativeGroup(), + } + require.Nil(t, getCachedDomain(key1), "cache should be empty initially") + domain1 := NewDomain(256, WithCache()) + expected1 := NewDomain(256) + require.Equal(t, domain1, expected1, "domain1 should equal expected1") + require.Same(t, domain1, getCachedDomain(key1), "domain1 should be stored in cache") + }) + + t.Run("CacheWithShift", func(t *testing.T) { + shift := fr.NewElement(5) + key2 := domainCacheKey{ + m: 512, + gen: shift, + } + require.Nil(t, getCachedDomain(key2), "cache should be empty initially") + domain2 := NewDomain(512, WithShift(shift), WithCache()) + expected2 := NewDomain(512, WithShift(shift)) + require.Equal(t, domain2, expected2, "domain2 should equal expected2") + require.Same(t, domain2, getCachedDomain(key2), "domain2 should be stored in cache") + }) +} + +func TestGCBehavior(t *testing.T) { + t.Run("DomainKeptAliveNotCollected", 
func(t *testing.T) { + domain := NewDomain(1<<20, WithCache()) + require.NotNil(t, domain, "domain should not be empty") + key := domainCacheKey{ + m: 1 << 20, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + runtime.GC() + require.NotNil(t, getCachedDomain(key), "domain should still be cached") + runtime.KeepAlive(domain) + }) + + t.Run("UnreferencedDomainCollected", func(t *testing.T) { + domain := NewDomain(1<<19, WithCache()) + key := domainCacheKey{ + m: 1 << 19, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + // Last use of domain + _ = domain.Cardinality + runtime.GC() + require.Nil(t, getCachedDomain(key), "unreferenced domain should be collected and removed from cache") + }) +} + +func BenchmarkNewDomainCache(b *testing.B) { + b.Run("WithCache", func(b *testing.B) { + // lets first initialize in cache already + cached := NewDomain(1<<20, WithCache()) + for b.Loop() { + _ = NewDomain(1<<20, WithCache()) + } + runtime.KeepAlive(cached) // prevent cached from being GCed + }) + + b.Run("WithoutCache", func(b *testing.B) { + for b.Loop() { + _ = NewDomain(1 << 20) + } + }) +} + +// Helper functions +func getCachedDomain(key domainCacheKey) *Domain { + keyMapLock.Lock() + keyLock, exists := domainGenLocks[key] + if !exists { + keyMapLock.Unlock() + return nil + } + + // Acquire key lock while holding global lock to prevent races + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + domainMapLock.Lock() + defer domainMapLock.Unlock() + if weak, exists := domainCache[key]; exists { + return weak.Value() + } + return nil +} diff --git a/ecc/bls24-317/fr/fft/options.go b/ecc/bls24-317/fr/fft/options.go index 9a7361935..6f629f63f 100644 --- a/ecc/bls24-317/fr/fft/options.go +++ b/ecc/bls24-317/fr/fft/options.go @@ -63,6 +63,7 @@ type DomainOption func(*domainConfig) type domainConfig struct { shift *fr.Element withPrecompute bool + withCache bool } // WithShift sets the FrMultiplicativeGen of the domain. @@ -81,11 +82,19 @@ func WithoutPrecompute() DomainOption { } } +// WithCache enables domain caching +func WithCache() DomainOption { + return func(opt *domainConfig) { + opt.withCache = true + } +} + // default options func domainOptions(opts ...DomainOption) domainConfig { // apply options opt := domainConfig{ withPrecompute: true, + withCache: false, } for _, option := range opts { option(&opt) diff --git a/ecc/bn254/fr/fft/domain.go b/ecc/bn254/fr/fft/domain.go index 5c9d3e545..fea6e366b 100644 --- a/ecc/bn254/fr/fft/domain.go +++ b/ecc/bn254/fr/fft/domain.go @@ -13,6 +13,7 @@ import ( "math/bits" "runtime" "sync" + "weak" "github.com/consensys/gnark-crypto/ecc/bn254/fr" @@ -61,11 +62,115 @@ func GeneratorFullMultiplicativeGroup() fr.Element { return res } -// NewDomain returns a subgroup with a power of 2 cardinality -// cardinality >= m -// shift: when specified, it's the element by which the set of root of unity is shifted. +// domainCacheKey is the composite key for the cache. +// It uses a struct with comparable types as the map key. 
+type domainCacheKey struct { + m uint64 + gen fr.Element +} + +var ( + domainCache = make(map[domainCacheKey]weak.Pointer[Domain]) + domainGenLocks = make(map[domainCacheKey]*sync.Mutex) // Per key mutex to avoid multiple concurrent generation of the same domain + keyMapLock sync.Mutex // Ensures exclusive access to domainGenLocks map + domainMapLock sync.Mutex // Ensures exclusive access to domainCache map +) + +// NewDomain returns a subgroup with a power of 2 cardinality >= m. +// +// Parameters: +// - m: minimum cardinality (will be rounded up to next power of 2) +// - opts: configuration options (WithShift, WithCache, WithoutPrecompute, etc.) +// +// The domain can be cached when both withCache and withPrecompute are enabled. +// Cached domains are automatically cleaned up when no longer in use. func NewDomain(m uint64, opts ...DomainOption) *Domain { opt := domainOptions(opts...) + + // Skip caching if disabled or precomputation is off + if !opt.withCache || !opt.withPrecompute { + return createDomain(m, opt) + } + + // Compute the cache key. + key := domainCacheKey{m: m} + if opt.shift != nil { + key.gen.Set(opt.shift) + } else { + key.gen = GeneratorFullMultiplicativeGroup() // Default generator + } + + // Lets ensure that only one goroutine is generating a domain for this + // specific key. We acquire it already here to ensure if there is a existing + // goroutine generating a domain for this key, we wait for it to finish and + // then we can just return the cached domain. + keyMapLock.Lock() + keyLock := domainGenLocks[key] + if keyLock == nil { + keyLock = new(sync.Mutex) + domainGenLocks[key] = keyLock + } + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + // Check cache first. But for the cache, we need to lock the cache map (we + // currently only hold the per-key lock, not global cache lock). + domainMapLock.Lock() + // we don't defer it because we want to release it while creating the + // domain. And domain creation can panic, leading to double unlock which + // hides the original panic. + if weakDomain, exists := domainCache[key]; exists { + if domain := weakDomain.Value(); domain != nil { + domainMapLock.Unlock() + return domain + } + } + // Lets release the global cache lock while we do this so that other keys + // can be added to cache. + domainMapLock.Unlock() + + // Create a new domain (expensive operation, but only blocks same key). + domain := createDomain(m, opt) + + // Store in cache with cleanup + weakDomain := weak.Make(domain) + domainMapLock.Lock() + domainCache[key] = weakDomain + domainMapLock.Unlock() + + // Add cleanup to remove from cache when domain is garbage collected + runtime.AddCleanup(domain, func(key domainCacheKey) { + // cleanup *may* be called concurrently, but could be sequential. We run + // it in a separate goroutine to avoid block other cleanups being run if + // this cleanup is being run on the same key which is being generated + // (thus lock being held). + go func() { + keyMapLock.Lock() + defer keyMapLock.Unlock() + if keyLock, ok := domainGenLocks[key]; ok { + keyLock.Lock() + defer keyLock.Unlock() + // We can now safely delete from both maps. But we only do if + // the cached weak pointer is the same one we created. Otherwise + // this means this cleanup is running after a new domain was + // already cached (double cleanup). 
+ + // We also want to hold both per-key and cache lock to avoid the + // maps being out of sync + domainMapLock.Lock() + defer domainMapLock.Unlock() + if cacheWeakDomain := domainCache[key]; cacheWeakDomain == weakDomain { + delete(domainCache, key) + delete(domainGenLocks, key) + } + } + }() + }, key) + return domain +} + +func createDomain(m uint64, opt domainConfig) *Domain { domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) diff --git a/ecc/bn254/fr/fft/domain_test.go b/ecc/bn254/fr/fft/domain_test.go index 7049120e6..f274817c0 100644 --- a/ecc/bn254/fr/fft/domain_test.go +++ b/ecc/bn254/fr/fft/domain_test.go @@ -8,7 +8,11 @@ package fft import ( "bytes" "reflect" + "runtime" "testing" + + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/stretchr/testify/require" ) func TestDomainSerialization(t *testing.T) { @@ -34,3 +38,97 @@ func TestDomainSerialization(t *testing.T) { t.Fatal("Domain.SetBytes(Bytes()) failed") } } + +func TestNewDomainCache(t *testing.T) { + t.Run("CacheWithoutShift", func(t *testing.T) { + key1 := domainCacheKey{ + m: 256, + gen: GeneratorFullMultiplicativeGroup(), + } + require.Nil(t, getCachedDomain(key1), "cache should be empty initially") + domain1 := NewDomain(256, WithCache()) + expected1 := NewDomain(256) + require.Equal(t, domain1, expected1, "domain1 should equal expected1") + require.Same(t, domain1, getCachedDomain(key1), "domain1 should be stored in cache") + }) + + t.Run("CacheWithShift", func(t *testing.T) { + shift := fr.NewElement(5) + key2 := domainCacheKey{ + m: 512, + gen: shift, + } + require.Nil(t, getCachedDomain(key2), "cache should be empty initially") + domain2 := NewDomain(512, WithShift(shift), WithCache()) + expected2 := NewDomain(512, WithShift(shift)) + require.Equal(t, domain2, expected2, "domain2 should equal expected2") + require.Same(t, domain2, getCachedDomain(key2), "domain2 should be stored in cache") + }) +} + +func TestGCBehavior(t *testing.T) { + t.Run("DomainKeptAliveNotCollected", func(t *testing.T) { + domain := NewDomain(1<<20, WithCache()) + require.NotNil(t, domain, "domain should not be empty") + key := domainCacheKey{ + m: 1 << 20, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + runtime.GC() + require.NotNil(t, getCachedDomain(key), "domain should still be cached") + runtime.KeepAlive(domain) + }) + + t.Run("UnreferencedDomainCollected", func(t *testing.T) { + domain := NewDomain(1<<19, WithCache()) + key := domainCacheKey{ + m: 1 << 19, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + // Last use of domain + _ = domain.Cardinality + runtime.GC() + require.Nil(t, getCachedDomain(key), "unreferenced domain should be collected and removed from cache") + }) +} + +func BenchmarkNewDomainCache(b *testing.B) { + b.Run("WithCache", func(b *testing.B) { + // lets first initialize in cache already + cached := NewDomain(1<<20, WithCache()) + for b.Loop() { + _ = NewDomain(1<<20, WithCache()) + } + runtime.KeepAlive(cached) // prevent cached from being GCed + }) + + b.Run("WithoutCache", func(b *testing.B) { + for b.Loop() { + _ = NewDomain(1 << 20) + } + }) +} + +// Helper functions +func getCachedDomain(key domainCacheKey) *Domain { + keyMapLock.Lock() + keyLock, exists := domainGenLocks[key] + if !exists { + keyMapLock.Unlock() + return nil + } + + // Acquire key lock while holding global lock to prevent races + keyLock.Lock() + defer 
keyLock.Unlock() + keyMapLock.Unlock() + + domainMapLock.Lock() + defer domainMapLock.Unlock() + if weak, exists := domainCache[key]; exists { + return weak.Value() + } + return nil +} diff --git a/ecc/bn254/fr/fft/options.go b/ecc/bn254/fr/fft/options.go index 87e5bae69..54ff79010 100644 --- a/ecc/bn254/fr/fft/options.go +++ b/ecc/bn254/fr/fft/options.go @@ -63,6 +63,7 @@ type DomainOption func(*domainConfig) type domainConfig struct { shift *fr.Element withPrecompute bool + withCache bool } // WithShift sets the FrMultiplicativeGen of the domain. @@ -81,11 +82,19 @@ func WithoutPrecompute() DomainOption { } } +// WithCache enables domain caching +func WithCache() DomainOption { + return func(opt *domainConfig) { + opt.withCache = true + } +} + // default options func domainOptions(opts ...DomainOption) domainConfig { // apply options opt := domainConfig{ withPrecompute: true, + withCache: false, } for _, option := range opts { option(&opt) diff --git a/ecc/bw6-633/fr/fft/domain.go b/ecc/bw6-633/fr/fft/domain.go index ab6cc41a9..6f83b4ea2 100644 --- a/ecc/bw6-633/fr/fft/domain.go +++ b/ecc/bw6-633/fr/fft/domain.go @@ -13,6 +13,7 @@ import ( "math/bits" "runtime" "sync" + "weak" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" @@ -61,11 +62,115 @@ func GeneratorFullMultiplicativeGroup() fr.Element { return res } -// NewDomain returns a subgroup with a power of 2 cardinality -// cardinality >= m -// shift: when specified, it's the element by which the set of root of unity is shifted. +// domainCacheKey is the composite key for the cache. +// It uses a struct with comparable types as the map key. +type domainCacheKey struct { + m uint64 + gen fr.Element +} + +var ( + domainCache = make(map[domainCacheKey]weak.Pointer[Domain]) + domainGenLocks = make(map[domainCacheKey]*sync.Mutex) // Per key mutex to avoid multiple concurrent generation of the same domain + keyMapLock sync.Mutex // Ensures exclusive access to domainGenLocks map + domainMapLock sync.Mutex // Ensures exclusive access to domainCache map +) + +// NewDomain returns a subgroup with a power of 2 cardinality >= m. +// +// Parameters: +// - m: minimum cardinality (will be rounded up to next power of 2) +// - opts: configuration options (WithShift, WithCache, WithoutPrecompute, etc.) +// +// The domain can be cached when both withCache and withPrecompute are enabled. +// Cached domains are automatically cleaned up when no longer in use. func NewDomain(m uint64, opts ...DomainOption) *Domain { opt := domainOptions(opts...) + + // Skip caching if disabled or precomputation is off + if !opt.withCache || !opt.withPrecompute { + return createDomain(m, opt) + } + + // Compute the cache key. + key := domainCacheKey{m: m} + if opt.shift != nil { + key.gen.Set(opt.shift) + } else { + key.gen = GeneratorFullMultiplicativeGroup() // Default generator + } + + // Lets ensure that only one goroutine is generating a domain for this + // specific key. We acquire it already here to ensure if there is a existing + // goroutine generating a domain for this key, we wait for it to finish and + // then we can just return the cached domain. + keyMapLock.Lock() + keyLock := domainGenLocks[key] + if keyLock == nil { + keyLock = new(sync.Mutex) + domainGenLocks[key] = keyLock + } + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + // Check cache first. But for the cache, we need to lock the cache map (we + // currently only hold the per-key lock, not global cache lock). 
+ domainMapLock.Lock() + // we don't defer it because we want to release it while creating the + // domain. And domain creation can panic, leading to double unlock which + // hides the original panic. + if weakDomain, exists := domainCache[key]; exists { + if domain := weakDomain.Value(); domain != nil { + domainMapLock.Unlock() + return domain + } + } + // Lets release the global cache lock while we do this so that other keys + // can be added to cache. + domainMapLock.Unlock() + + // Create a new domain (expensive operation, but only blocks same key). + domain := createDomain(m, opt) + + // Store in cache with cleanup + weakDomain := weak.Make(domain) + domainMapLock.Lock() + domainCache[key] = weakDomain + domainMapLock.Unlock() + + // Add cleanup to remove from cache when domain is garbage collected + runtime.AddCleanup(domain, func(key domainCacheKey) { + // cleanup *may* be called concurrently, but could be sequential. We run + // it in a separate goroutine to avoid block other cleanups being run if + // this cleanup is being run on the same key which is being generated + // (thus lock being held). + go func() { + keyMapLock.Lock() + defer keyMapLock.Unlock() + if keyLock, ok := domainGenLocks[key]; ok { + keyLock.Lock() + defer keyLock.Unlock() + // We can now safely delete from both maps. But we only do if + // the cached weak pointer is the same one we created. Otherwise + // this means this cleanup is running after a new domain was + // already cached (double cleanup). + + // We also want to hold both per-key and cache lock to avoid the + // maps being out of sync + domainMapLock.Lock() + defer domainMapLock.Unlock() + if cacheWeakDomain := domainCache[key]; cacheWeakDomain == weakDomain { + delete(domainCache, key) + delete(domainGenLocks, key) + } + } + }() + }, key) + return domain +} + +func createDomain(m uint64, opt domainConfig) *Domain { domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) diff --git a/ecc/bw6-633/fr/fft/domain_test.go b/ecc/bw6-633/fr/fft/domain_test.go index 7049120e6..8cd438b87 100644 --- a/ecc/bw6-633/fr/fft/domain_test.go +++ b/ecc/bw6-633/fr/fft/domain_test.go @@ -8,7 +8,11 @@ package fft import ( "bytes" "reflect" + "runtime" "testing" + + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/stretchr/testify/require" ) func TestDomainSerialization(t *testing.T) { @@ -34,3 +38,97 @@ func TestDomainSerialization(t *testing.T) { t.Fatal("Domain.SetBytes(Bytes()) failed") } } + +func TestNewDomainCache(t *testing.T) { + t.Run("CacheWithoutShift", func(t *testing.T) { + key1 := domainCacheKey{ + m: 256, + gen: GeneratorFullMultiplicativeGroup(), + } + require.Nil(t, getCachedDomain(key1), "cache should be empty initially") + domain1 := NewDomain(256, WithCache()) + expected1 := NewDomain(256) + require.Equal(t, domain1, expected1, "domain1 should equal expected1") + require.Same(t, domain1, getCachedDomain(key1), "domain1 should be stored in cache") + }) + + t.Run("CacheWithShift", func(t *testing.T) { + shift := fr.NewElement(5) + key2 := domainCacheKey{ + m: 512, + gen: shift, + } + require.Nil(t, getCachedDomain(key2), "cache should be empty initially") + domain2 := NewDomain(512, WithShift(shift), WithCache()) + expected2 := NewDomain(512, WithShift(shift)) + require.Equal(t, domain2, expected2, "domain2 should equal expected2") + require.Same(t, domain2, getCachedDomain(key2), "domain2 should be stored in cache") + }) +} + +func TestGCBehavior(t *testing.T) { + t.Run("DomainKeptAliveNotCollected", func(t 
*testing.T) { + domain := NewDomain(1<<20, WithCache()) + require.NotNil(t, domain, "domain should not be empty") + key := domainCacheKey{ + m: 1 << 20, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + runtime.GC() + require.NotNil(t, getCachedDomain(key), "domain should still be cached") + runtime.KeepAlive(domain) + }) + + t.Run("UnreferencedDomainCollected", func(t *testing.T) { + domain := NewDomain(1<<19, WithCache()) + key := domainCacheKey{ + m: 1 << 19, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + // Last use of domain + _ = domain.Cardinality + runtime.GC() + require.Nil(t, getCachedDomain(key), "unreferenced domain should be collected and removed from cache") + }) +} + +func BenchmarkNewDomainCache(b *testing.B) { + b.Run("WithCache", func(b *testing.B) { + // lets first initialize in cache already + cached := NewDomain(1<<20, WithCache()) + for b.Loop() { + _ = NewDomain(1<<20, WithCache()) + } + runtime.KeepAlive(cached) // prevent cached from being GCed + }) + + b.Run("WithoutCache", func(b *testing.B) { + for b.Loop() { + _ = NewDomain(1 << 20) + } + }) +} + +// Helper functions +func getCachedDomain(key domainCacheKey) *Domain { + keyMapLock.Lock() + keyLock, exists := domainGenLocks[key] + if !exists { + keyMapLock.Unlock() + return nil + } + + // Acquire key lock while holding global lock to prevent races + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + domainMapLock.Lock() + defer domainMapLock.Unlock() + if weak, exists := domainCache[key]; exists { + return weak.Value() + } + return nil +} diff --git a/ecc/bw6-633/fr/fft/options.go b/ecc/bw6-633/fr/fft/options.go index 3b9f572b4..2f2ee7e39 100644 --- a/ecc/bw6-633/fr/fft/options.go +++ b/ecc/bw6-633/fr/fft/options.go @@ -63,6 +63,7 @@ type DomainOption func(*domainConfig) type domainConfig struct { shift *fr.Element withPrecompute bool + withCache bool } // WithShift sets the FrMultiplicativeGen of the domain. @@ -81,11 +82,19 @@ func WithoutPrecompute() DomainOption { } } +// WithCache enables domain caching +func WithCache() DomainOption { + return func(opt *domainConfig) { + opt.withCache = true + } +} + // default options func domainOptions(opts ...DomainOption) domainConfig { // apply options opt := domainConfig{ withPrecompute: true, + withCache: false, } for _, option := range opts { option(&opt) diff --git a/ecc/bw6-761/fr/fft/domain.go b/ecc/bw6-761/fr/fft/domain.go index 079b9ada8..bc2569e6a 100644 --- a/ecc/bw6-761/fr/fft/domain.go +++ b/ecc/bw6-761/fr/fft/domain.go @@ -13,6 +13,7 @@ import ( "math/bits" "runtime" "sync" + "weak" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" @@ -61,11 +62,115 @@ func GeneratorFullMultiplicativeGroup() fr.Element { return res } -// NewDomain returns a subgroup with a power of 2 cardinality -// cardinality >= m -// shift: when specified, it's the element by which the set of root of unity is shifted. +// domainCacheKey is the composite key for the cache. +// It uses a struct with comparable types as the map key. 
+type domainCacheKey struct { + m uint64 + gen fr.Element +} + +var ( + domainCache = make(map[domainCacheKey]weak.Pointer[Domain]) + domainGenLocks = make(map[domainCacheKey]*sync.Mutex) // Per key mutex to avoid multiple concurrent generation of the same domain + keyMapLock sync.Mutex // Ensures exclusive access to domainGenLocks map + domainMapLock sync.Mutex // Ensures exclusive access to domainCache map +) + +// NewDomain returns a subgroup with a power of 2 cardinality >= m. +// +// Parameters: +// - m: minimum cardinality (will be rounded up to next power of 2) +// - opts: configuration options (WithShift, WithCache, WithoutPrecompute, etc.) +// +// The domain can be cached when both withCache and withPrecompute are enabled. +// Cached domains are automatically cleaned up when no longer in use. func NewDomain(m uint64, opts ...DomainOption) *Domain { opt := domainOptions(opts...) + + // Skip caching if disabled or precomputation is off + if !opt.withCache || !opt.withPrecompute { + return createDomain(m, opt) + } + + // Compute the cache key. + key := domainCacheKey{m: m} + if opt.shift != nil { + key.gen.Set(opt.shift) + } else { + key.gen = GeneratorFullMultiplicativeGroup() // Default generator + } + + // Lets ensure that only one goroutine is generating a domain for this + // specific key. We acquire it already here to ensure if there is a existing + // goroutine generating a domain for this key, we wait for it to finish and + // then we can just return the cached domain. + keyMapLock.Lock() + keyLock := domainGenLocks[key] + if keyLock == nil { + keyLock = new(sync.Mutex) + domainGenLocks[key] = keyLock + } + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + // Check cache first. But for the cache, we need to lock the cache map (we + // currently only hold the per-key lock, not global cache lock). + domainMapLock.Lock() + // we don't defer it because we want to release it while creating the + // domain. And domain creation can panic, leading to double unlock which + // hides the original panic. + if weakDomain, exists := domainCache[key]; exists { + if domain := weakDomain.Value(); domain != nil { + domainMapLock.Unlock() + return domain + } + } + // Lets release the global cache lock while we do this so that other keys + // can be added to cache. + domainMapLock.Unlock() + + // Create a new domain (expensive operation, but only blocks same key). + domain := createDomain(m, opt) + + // Store in cache with cleanup + weakDomain := weak.Make(domain) + domainMapLock.Lock() + domainCache[key] = weakDomain + domainMapLock.Unlock() + + // Add cleanup to remove from cache when domain is garbage collected + runtime.AddCleanup(domain, func(key domainCacheKey) { + // cleanup *may* be called concurrently, but could be sequential. We run + // it in a separate goroutine to avoid block other cleanups being run if + // this cleanup is being run on the same key which is being generated + // (thus lock being held). + go func() { + keyMapLock.Lock() + defer keyMapLock.Unlock() + if keyLock, ok := domainGenLocks[key]; ok { + keyLock.Lock() + defer keyLock.Unlock() + // We can now safely delete from both maps. But we only do if + // the cached weak pointer is the same one we created. Otherwise + // this means this cleanup is running after a new domain was + // already cached (double cleanup). 
+ + // We also want to hold both per-key and cache lock to avoid the + // maps being out of sync + domainMapLock.Lock() + defer domainMapLock.Unlock() + if cacheWeakDomain := domainCache[key]; cacheWeakDomain == weakDomain { + delete(domainCache, key) + delete(domainGenLocks, key) + } + } + }() + }, key) + return domain +} + +func createDomain(m uint64, opt domainConfig) *Domain { domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) diff --git a/ecc/bw6-761/fr/fft/domain_test.go b/ecc/bw6-761/fr/fft/domain_test.go index 7049120e6..6752274e5 100644 --- a/ecc/bw6-761/fr/fft/domain_test.go +++ b/ecc/bw6-761/fr/fft/domain_test.go @@ -8,7 +8,11 @@ package fft import ( "bytes" "reflect" + "runtime" "testing" + + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/stretchr/testify/require" ) func TestDomainSerialization(t *testing.T) { @@ -34,3 +38,97 @@ func TestDomainSerialization(t *testing.T) { t.Fatal("Domain.SetBytes(Bytes()) failed") } } + +func TestNewDomainCache(t *testing.T) { + t.Run("CacheWithoutShift", func(t *testing.T) { + key1 := domainCacheKey{ + m: 256, + gen: GeneratorFullMultiplicativeGroup(), + } + require.Nil(t, getCachedDomain(key1), "cache should be empty initially") + domain1 := NewDomain(256, WithCache()) + expected1 := NewDomain(256) + require.Equal(t, domain1, expected1, "domain1 should equal expected1") + require.Same(t, domain1, getCachedDomain(key1), "domain1 should be stored in cache") + }) + + t.Run("CacheWithShift", func(t *testing.T) { + shift := fr.NewElement(5) + key2 := domainCacheKey{ + m: 512, + gen: shift, + } + require.Nil(t, getCachedDomain(key2), "cache should be empty initially") + domain2 := NewDomain(512, WithShift(shift), WithCache()) + expected2 := NewDomain(512, WithShift(shift)) + require.Equal(t, domain2, expected2, "domain2 should equal expected2") + require.Same(t, domain2, getCachedDomain(key2), "domain2 should be stored in cache") + }) +} + +func TestGCBehavior(t *testing.T) { + t.Run("DomainKeptAliveNotCollected", func(t *testing.T) { + domain := NewDomain(1<<20, WithCache()) + require.NotNil(t, domain, "domain should not be empty") + key := domainCacheKey{ + m: 1 << 20, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + runtime.GC() + require.NotNil(t, getCachedDomain(key), "domain should still be cached") + runtime.KeepAlive(domain) + }) + + t.Run("UnreferencedDomainCollected", func(t *testing.T) { + domain := NewDomain(1<<19, WithCache()) + key := domainCacheKey{ + m: 1 << 19, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + // Last use of domain + _ = domain.Cardinality + runtime.GC() + require.Nil(t, getCachedDomain(key), "unreferenced domain should be collected and removed from cache") + }) +} + +func BenchmarkNewDomainCache(b *testing.B) { + b.Run("WithCache", func(b *testing.B) { + // lets first initialize in cache already + cached := NewDomain(1<<20, WithCache()) + for b.Loop() { + _ = NewDomain(1<<20, WithCache()) + } + runtime.KeepAlive(cached) // prevent cached from being GCed + }) + + b.Run("WithoutCache", func(b *testing.B) { + for b.Loop() { + _ = NewDomain(1 << 20) + } + }) +} + +// Helper functions +func getCachedDomain(key domainCacheKey) *Domain { + keyMapLock.Lock() + keyLock, exists := domainGenLocks[key] + if !exists { + keyMapLock.Unlock() + return nil + } + + // Acquire key lock while holding global lock to prevent races + keyLock.Lock() + defer 
keyLock.Unlock() + keyMapLock.Unlock() + + domainMapLock.Lock() + defer domainMapLock.Unlock() + if weak, exists := domainCache[key]; exists { + return weak.Value() + } + return nil +} diff --git a/ecc/bw6-761/fr/fft/options.go b/ecc/bw6-761/fr/fft/options.go index 276471bd1..6bd158b94 100644 --- a/ecc/bw6-761/fr/fft/options.go +++ b/ecc/bw6-761/fr/fft/options.go @@ -63,6 +63,7 @@ type DomainOption func(*domainConfig) type domainConfig struct { shift *fr.Element withPrecompute bool + withCache bool } // WithShift sets the FrMultiplicativeGen of the domain. @@ -81,11 +82,19 @@ func WithoutPrecompute() DomainOption { } } +// WithCache enables domain caching +func WithCache() DomainOption { + return func(opt *domainConfig) { + opt.withCache = true + } +} + // default options func domainOptions(opts ...DomainOption) domainConfig { // apply options opt := domainConfig{ withPrecompute: true, + withCache: false, } for _, option := range opts { option(&opt) diff --git a/field/babybear/extensions/e2_test.go b/field/babybear/extensions/e2_test.go index 68178cd56..e95d62713 100644 --- a/field/babybear/extensions/e2_test.go +++ b/field/babybear/extensions/e2_test.go @@ -513,3 +513,36 @@ func TestE2Div(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } + +var modulus = fr.Modulus() + +// genFr generates an Fr element +func genFr() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + var elmt fr.Element + // SetBigInt will reduce the value modulo the field order + // genParams.Rng is a math/rand.Rand which is not a cryptographically secure + // source of randomness. However, for property based testing, it is desirable + // to have a deterministic generator. + e := bigIntPool.Get().(*big.Int) + e.Rand(genParams.Rng, modulus) + + for i, w := range e.Bits() { + elmt[i] = uint32(w) + } + bigIntPool.Put(e) + + genResult := gopter.NewGenResult(elmt, gopter.NoShrinker) + return genResult + } +} + +// genE2 generates an E2 element +func genE2() gopter.Gen { + return gopter.CombineGens( + genFr(), + genFr(), + ).Map(func(values []interface{}) E2 { + return E2{A0: values[0].(fr.Element), A1: values[1].(fr.Element)} + }) +} diff --git a/field/babybear/fft/domain.go b/field/babybear/fft/domain.go index 49a6ab8af..cdb69ae2b 100644 --- a/field/babybear/fft/domain.go +++ b/field/babybear/fft/domain.go @@ -13,6 +13,7 @@ import ( "math/bits" "runtime" "sync" + "weak" "github.com/consensys/gnark-crypto/field/babybear" @@ -61,11 +62,115 @@ func GeneratorFullMultiplicativeGroup() babybear.Element { return res } -// NewDomain returns a subgroup with a power of 2 cardinality -// cardinality >= m -// shift: when specified, it's the element by which the set of root of unity is shifted. +// domainCacheKey is the composite key for the cache. +// It uses a struct with comparable types as the map key. +type domainCacheKey struct { + m uint64 + gen babybear.Element +} + +var ( + domainCache = make(map[domainCacheKey]weak.Pointer[Domain]) + domainGenLocks = make(map[domainCacheKey]*sync.Mutex) // Per key mutex to avoid multiple concurrent generation of the same domain + keyMapLock sync.Mutex // Ensures exclusive access to domainGenLocks map + domainMapLock sync.Mutex // Ensures exclusive access to domainCache map +) + +// NewDomain returns a subgroup with a power of 2 cardinality >= m. +// +// Parameters: +// - m: minimum cardinality (will be rounded up to next power of 2) +// - opts: configuration options (WithShift, WithCache, WithoutPrecompute, etc.) 
+// +// The domain can be cached when both withCache and withPrecompute are enabled. +// Cached domains are automatically cleaned up when no longer in use. func NewDomain(m uint64, opts ...DomainOption) *Domain { opt := domainOptions(opts...) + + // Skip caching if disabled or precomputation is off + if !opt.withCache || !opt.withPrecompute { + return createDomain(m, opt) + } + + // Compute the cache key. + key := domainCacheKey{m: m} + if opt.shift != nil { + key.gen.Set(opt.shift) + } else { + key.gen = GeneratorFullMultiplicativeGroup() // Default generator + } + + // Lets ensure that only one goroutine is generating a domain for this + // specific key. We acquire it already here to ensure if there is a existing + // goroutine generating a domain for this key, we wait for it to finish and + // then we can just return the cached domain. + keyMapLock.Lock() + keyLock := domainGenLocks[key] + if keyLock == nil { + keyLock = new(sync.Mutex) + domainGenLocks[key] = keyLock + } + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + // Check cache first. But for the cache, we need to lock the cache map (we + // currently only hold the per-key lock, not global cache lock). + domainMapLock.Lock() + // we don't defer it because we want to release it while creating the + // domain. And domain creation can panic, leading to double unlock which + // hides the original panic. + if weakDomain, exists := domainCache[key]; exists { + if domain := weakDomain.Value(); domain != nil { + domainMapLock.Unlock() + return domain + } + } + // Lets release the global cache lock while we do this so that other keys + // can be added to cache. + domainMapLock.Unlock() + + // Create a new domain (expensive operation, but only blocks same key). + domain := createDomain(m, opt) + + // Store in cache with cleanup + weakDomain := weak.Make(domain) + domainMapLock.Lock() + domainCache[key] = weakDomain + domainMapLock.Unlock() + + // Add cleanup to remove from cache when domain is garbage collected + runtime.AddCleanup(domain, func(key domainCacheKey) { + // cleanup *may* be called concurrently, but could be sequential. We run + // it in a separate goroutine to avoid block other cleanups being run if + // this cleanup is being run on the same key which is being generated + // (thus lock being held). + go func() { + keyMapLock.Lock() + defer keyMapLock.Unlock() + if keyLock, ok := domainGenLocks[key]; ok { + keyLock.Lock() + defer keyLock.Unlock() + // We can now safely delete from both maps. But we only do if + // the cached weak pointer is the same one we created. Otherwise + // this means this cleanup is running after a new domain was + // already cached (double cleanup). 
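The per-key mutex dance above is essentially a singleflight: the global keyMapLock only guards the map of per-key locks, while each per-key lock serializes generation of one (size, shift) pair without blocking other pairs. A stripped-down sketch of the same pattern (lockKey, build and the string keys are invented for illustration; the real code additionally re-checks the cache under the per-key lock and releases the global lock only after the per-key lock is held):

package main

import (
	"fmt"
	"sync"
)

var (
	buildLocksMu sync.Mutex
	buildLocks   = make(map[string]*sync.Mutex)
)

// lockKey returns the mutex dedicated to key, creating it on first use.
// The global mutex only protects the map of per-key mutexes, so holding
// one key's lock never blocks work on a different key.
func lockKey(key string) *sync.Mutex {
	buildLocksMu.Lock()
	defer buildLocksMu.Unlock()
	l, ok := buildLocks[key]
	if !ok {
		l = new(sync.Mutex)
		buildLocks[key] = l
	}
	return l
}

func build(key string) string {
	l := lockKey(key)
	l.Lock()
	defer l.Unlock()
	// ... a real implementation would check a cache here and only build on a miss ...
	return "built:" + key
}

func main() {
	var wg sync.WaitGroup
	for _, k := range []string{"a", "a", "b"} {
		wg.Add(1)
		go func(k string) { defer wg.Done(); fmt.Println(build(k)) }(k)
	}
	wg.Wait()
}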
+ + // We also want to hold both per-key and cache lock to avoid the + // maps being out of sync + domainMapLock.Lock() + defer domainMapLock.Unlock() + if cacheWeakDomain := domainCache[key]; cacheWeakDomain == weakDomain { + delete(domainCache, key) + delete(domainGenLocks, key) + } + } + }() + }, key) + return domain +} + +func createDomain(m uint64, opt domainConfig) *Domain { domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) diff --git a/field/babybear/fft/domain_test.go b/field/babybear/fft/domain_test.go index 7049120e6..b9e8f1a38 100644 --- a/field/babybear/fft/domain_test.go +++ b/field/babybear/fft/domain_test.go @@ -8,7 +8,11 @@ package fft import ( "bytes" "reflect" + "runtime" "testing" + + "github.com/consensys/gnark-crypto/field/babybear" + "github.com/stretchr/testify/require" ) func TestDomainSerialization(t *testing.T) { @@ -34,3 +38,97 @@ func TestDomainSerialization(t *testing.T) { t.Fatal("Domain.SetBytes(Bytes()) failed") } } + +func TestNewDomainCache(t *testing.T) { + t.Run("CacheWithoutShift", func(t *testing.T) { + key1 := domainCacheKey{ + m: 256, + gen: GeneratorFullMultiplicativeGroup(), + } + require.Nil(t, getCachedDomain(key1), "cache should be empty initially") + domain1 := NewDomain(256, WithCache()) + expected1 := NewDomain(256) + require.Equal(t, domain1, expected1, "domain1 should equal expected1") + require.Same(t, domain1, getCachedDomain(key1), "domain1 should be stored in cache") + }) + + t.Run("CacheWithShift", func(t *testing.T) { + shift := babybear.NewElement(5) + key2 := domainCacheKey{ + m: 512, + gen: shift, + } + require.Nil(t, getCachedDomain(key2), "cache should be empty initially") + domain2 := NewDomain(512, WithShift(shift), WithCache()) + expected2 := NewDomain(512, WithShift(shift)) + require.Equal(t, domain2, expected2, "domain2 should equal expected2") + require.Same(t, domain2, getCachedDomain(key2), "domain2 should be stored in cache") + }) +} + +func TestGCBehavior(t *testing.T) { + t.Run("DomainKeptAliveNotCollected", func(t *testing.T) { + domain := NewDomain(1<<20, WithCache()) + require.NotNil(t, domain, "domain should not be empty") + key := domainCacheKey{ + m: 1 << 20, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + runtime.GC() + require.NotNil(t, getCachedDomain(key), "domain should still be cached") + runtime.KeepAlive(domain) + }) + + t.Run("UnreferencedDomainCollected", func(t *testing.T) { + domain := NewDomain(1<<19, WithCache()) + key := domainCacheKey{ + m: 1 << 19, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + // Last use of domain + _ = domain.Cardinality + runtime.GC() + require.Nil(t, getCachedDomain(key), "unreferenced domain should be collected and removed from cache") + }) +} + +func BenchmarkNewDomainCache(b *testing.B) { + b.Run("WithCache", func(b *testing.B) { + // lets first initialize in cache already + cached := NewDomain(1<<20, WithCache()) + for b.Loop() { + _ = NewDomain(1<<20, WithCache()) + } + runtime.KeepAlive(cached) // prevent cached from being GCed + }) + + b.Run("WithoutCache", func(b *testing.B) { + for b.Loop() { + _ = NewDomain(1 << 20) + } + }) +} + +// Helper functions +func getCachedDomain(key domainCacheKey) *Domain { + keyMapLock.Lock() + keyLock, exists := domainGenLocks[key] + if !exists { + keyMapLock.Unlock() + return nil + } + + // Acquire key lock while holding global lock to prevent races + keyLock.Lock() + 
defer keyLock.Unlock() + keyMapLock.Unlock() + + domainMapLock.Lock() + defer domainMapLock.Unlock() + if weak, exists := domainCache[key]; exists { + return weak.Value() + } + return nil +} diff --git a/field/babybear/fft/options.go b/field/babybear/fft/options.go index d6fcb9c15..a586629a7 100644 --- a/field/babybear/fft/options.go +++ b/field/babybear/fft/options.go @@ -63,6 +63,7 @@ type DomainOption func(*domainConfig) type domainConfig struct { shift *babybear.Element withPrecompute bool + withCache bool } // WithShift sets the FrMultiplicativeGen of the domain. @@ -81,11 +82,19 @@ func WithoutPrecompute() DomainOption { } } +// WithCache enables domain caching +func WithCache() DomainOption { + return func(opt *domainConfig) { + opt.withCache = true + } +} + // default options func domainOptions(opts ...DomainOption) domainConfig { // apply options opt := domainConfig{ withPrecompute: true, + withCache: false, } for _, option := range opts { option(&opt) diff --git a/field/generator/internal/templates/extensions/e2_test.go.tmpl b/field/generator/internal/templates/extensions/e2_test.go.tmpl index 502d21bd4..5be2f5a94 100644 --- a/field/generator/internal/templates/extensions/e2_test.go.tmpl +++ b/field/generator/internal/templates/extensions/e2_test.go.tmpl @@ -512,3 +512,43 @@ func TestE2Div(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } + + + + +var modulus = fr.Modulus() + +// genFr generates an Fr element +func genFr() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + var elmt fr.Element + // SetBigInt will reduce the value modulo the field order + // genParams.Rng is a math/rand.Rand which is not a cryptographically secure + // source of randomness. However, for property based testing, it is desirable + // to have a deterministic generator. + e := bigIntPool.Get().(*big.Int) + e.Rand(genParams.Rng, modulus) + + for i, w := range e.Bits() { + {{- if or (eq .FF "babybear") (eq .FF "koalabear")}} + elmt[i] = uint32(w) + {{- else}} + elmt[i] = uint64(w) + {{- end}} + } + bigIntPool.Put(e) + + genResult := gopter.NewGenResult(elmt, gopter.NoShrinker) + return genResult + } +} + +// genE2 generates an E2 element +func genE2() gopter.Gen { + return gopter.CombineGens( + genFr(), + genFr(), + ).Map(func(values []interface{}) E2 { + return E2{A0: values[0].(fr.Element), A1: values[1].(fr.Element)} + }) +} diff --git a/field/generator/internal/templates/fft/domain.go.tmpl b/field/generator/internal/templates/fft/domain.go.tmpl index 7e1d2e082..88d1fe93f 100644 --- a/field/generator/internal/templates/fft/domain.go.tmpl +++ b/field/generator/internal/templates/fft/domain.go.tmpl @@ -4,6 +4,7 @@ import ( "math/bits" "runtime" "sync" + "weak" "errors" "encoding/binary" @@ -54,17 +55,122 @@ func GeneratorFullMultiplicativeGroup() {{ .FF }}.Element { return res } -// NewDomain returns a subgroup with a power of 2 cardinality -// cardinality >= m -// shift: when specified, it's the element by which the set of root of unity is shifted. + +// domainCacheKey is the composite key for the cache. +// It uses a struct with comparable types as the map key. 
+type domainCacheKey struct { + m uint64 + gen {{ .FF }}.Element +} + +var ( + domainCache = make(map[domainCacheKey]weak.Pointer[Domain]) + domainGenLocks = make(map[domainCacheKey]*sync.Mutex) // Per key mutex to avoid multiple concurrent generation of the same domain + keyMapLock sync.Mutex // Ensures exclusive access to domainGenLocks map + domainMapLock sync.Mutex // Ensures exclusive access to domainCache map +) + +// NewDomain returns a subgroup with a power of 2 cardinality >= m. +// +// Parameters: +// - m: minimum cardinality (will be rounded up to next power of 2) +// - opts: configuration options (WithShift, WithCache, WithoutPrecompute, etc.) +// +// The domain can be cached when both withCache and withPrecompute are enabled. +// Cached domains are automatically cleaned up when no longer in use. func NewDomain(m uint64, opts ...DomainOption) *Domain { opt := domainOptions(opts...) + + // Skip caching if disabled or precomputation is off + if !opt.withCache || !opt.withPrecompute { + return createDomain(m, opt) + } + + // Compute the cache key. + key := domainCacheKey{m: m} + if opt.shift != nil { + key.gen.Set(opt.shift) + } else { + key.gen = GeneratorFullMultiplicativeGroup() // Default generator + } + + // Lets ensure that only one goroutine is generating a domain for this + // specific key. We acquire it already here to ensure if there is a existing + // goroutine generating a domain for this key, we wait for it to finish and + // then we can just return the cached domain. + keyMapLock.Lock() + keyLock := domainGenLocks[key] + if keyLock == nil { + keyLock = new(sync.Mutex) + domainGenLocks[key] = keyLock + } + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + // Check cache first. But for the cache, we need to lock the cache map (we + // currently only hold the per-key lock, not global cache lock). + domainMapLock.Lock() + // we don't defer it because we want to release it while creating the + // domain. And domain creation can panic, leading to double unlock which + // hides the original panic. + if weakDomain, exists := domainCache[key]; exists { + if domain := weakDomain.Value(); domain != nil { + domainMapLock.Unlock() + return domain + } + } + // Lets release the global cache lock while we do this so that other keys + // can be added to cache. + domainMapLock.Unlock() + + // Create a new domain (expensive operation, but only blocks same key). + domain := createDomain(m, opt) + + // Store in cache with cleanup + weakDomain := weak.Make(domain) + domainMapLock.Lock() + domainCache[key] = weakDomain + domainMapLock.Unlock() + + // Add cleanup to remove from cache when domain is garbage collected + runtime.AddCleanup(domain, func(key domainCacheKey) { + // cleanup *may* be called concurrently, but could be sequential. We run + // it in a separate goroutine to avoid block other cleanups being run if + // this cleanup is being run on the same key which is being generated + // (thus lock being held). + go func() { + keyMapLock.Lock() + defer keyMapLock.Unlock() + if keyLock, ok := domainGenLocks[key]; ok { + keyLock.Lock() + defer keyLock.Unlock() + // We can now safely delete from both maps. But we only do if + // the cached weak pointer is the same one we created. Otherwise + // this means this cleanup is running after a new domain was + // already cached (double cleanup). 
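runtime.AddCleanup (Go 1.24+) is what makes eviction automatic here: the callback receives only the small key argument, never the *Domain itself, so it cannot keep the domain reachable, and it runs some time after the last strong reference is gone. A minimal sketch of the mechanism (the resource type and the printed message are illustrative only; when, and whether, the cleanup fires before the program exits is up to the garbage collector):

package main

import (
	"fmt"
	"runtime"
	"time"
)

type resource struct{ name string }

func main() {
	r := &resource{name: "domain-1"}

	// The cleanup closes over nothing and receives only the key, so it does
	// not keep r alive (which would otherwise prevent it from ever running).
	runtime.AddCleanup(r, func(key string) {
		fmt.Println("evicting cache entry for", key)
	}, "domain-1")

	r = nil      // drop the last strong reference
	runtime.GC() // makes the cleanup eligible to run; not a guarantee that it has run
	time.Sleep(100 * time.Millisecond)
}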
+ + // We also want to hold both per-key and cache lock to avoid the + // maps being out of sync + domainMapLock.Lock() + defer domainMapLock.Unlock() + if cacheWeakDomain := domainCache[key]; cacheWeakDomain == weakDomain { + delete(domainCache, key) + delete(domainGenLocks, key) + } + } + }() + }, key) + return domain +} + +func createDomain(m uint64, opt domainConfig) *Domain { domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) domain.FrMultiplicativeGen = GeneratorFullMultiplicativeGroup() - if opt.shift != nil{ + if opt.shift != nil { domain.FrMultiplicativeGen.Set(opt.shift) } domain.FrMultiplicativeGenInv.Inverse(&domain.FrMultiplicativeGen) diff --git a/field/generator/internal/templates/fft/options.go.tmpl b/field/generator/internal/templates/fft/options.go.tmpl index 87fe89ff4..9213fb8a6 100644 --- a/field/generator/internal/templates/fft/options.go.tmpl +++ b/field/generator/internal/templates/fft/options.go.tmpl @@ -57,6 +57,7 @@ type DomainOption func(*domainConfig) type domainConfig struct { shift *{{ .FF }}.Element withPrecompute bool + withCache bool } // WithShift sets the FrMultiplicativeGen of the domain. @@ -75,11 +76,19 @@ func WithoutPrecompute() DomainOption { } } +// WithCache enables domain caching +func WithCache() DomainOption { + return func(opt *domainConfig) { + opt.withCache = true + } +} + // default options func domainOptions(opts ...DomainOption) domainConfig { // apply options opt := domainConfig{ withPrecompute: true, + withCache: false, } for _, option := range opts { option(&opt) diff --git a/field/generator/internal/templates/fft/tests/domain.go.tmpl b/field/generator/internal/templates/fft/tests/domain.go.tmpl index ff775ad36..71d9b4cc0 100644 --- a/field/generator/internal/templates/fft/tests/domain.go.tmpl +++ b/field/generator/internal/templates/fft/tests/domain.go.tmpl @@ -1,8 +1,13 @@ import ( + "bytes" "reflect" + "runtime" "testing" - "bytes" + + "{{ .FieldPackagePath }}" + "github.com/stretchr/testify/require" + ) func TestDomainSerialization(t *testing.T) { @@ -27,4 +32,99 @@ func TestDomainSerialization(t *testing.T) { if !reflect.DeepEqual(domain, &reconstructed) { t.Fatal("Domain.SetBytes(Bytes()) failed") } -} \ No newline at end of file +} + + +func TestNewDomainCache(t *testing.T) { + t.Run("CacheWithoutShift", func(t *testing.T) { + key1 := domainCacheKey{ + m: 256, + gen: GeneratorFullMultiplicativeGroup(), + } + require.Nil(t, getCachedDomain(key1), "cache should be empty initially") + domain1 := NewDomain(256, WithCache()) + expected1 := NewDomain(256) + require.Equal(t, domain1, expected1, "domain1 should equal expected1") + require.Same(t, domain1, getCachedDomain(key1), "domain1 should be stored in cache") + }) + + t.Run("CacheWithShift", func(t *testing.T) { + shift := {{ .FF }}.NewElement(5) + key2 := domainCacheKey{ + m: 512, + gen: shift, + } + require.Nil(t, getCachedDomain(key2), "cache should be empty initially") + domain2 := NewDomain(512, WithShift(shift), WithCache()) + expected2 := NewDomain(512, WithShift(shift)) + require.Equal(t, domain2, expected2, "domain2 should equal expected2") + require.Same(t, domain2, getCachedDomain(key2), "domain2 should be stored in cache") + }) +} + +func TestGCBehavior(t *testing.T) { + t.Run("DomainKeptAliveNotCollected", func(t *testing.T) { + domain := NewDomain(1<<20, WithCache()) + require.NotNil(t, domain, "domain should not be empty") + key := domainCacheKey{ + m: 1 << 20, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, 
getCachedDomain(key), "domain should be cached") + runtime.GC() + require.NotNil(t, getCachedDomain(key), "domain should still be cached") + runtime.KeepAlive(domain) + }) + + t.Run("UnreferencedDomainCollected", func(t *testing.T) { + domain := NewDomain(1<<19, WithCache()) + key := domainCacheKey{ + m: 1 << 19, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + // Last use of domain + _ = domain.Cardinality + runtime.GC() + require.Nil(t, getCachedDomain(key), "unreferenced domain should be collected and removed from cache") + }) +} + +func BenchmarkNewDomainCache(b *testing.B) { + b.Run("WithCache", func(b *testing.B) { + // lets first initialize in cache already + cached := NewDomain(1<<20, WithCache()) + for b.Loop() { + _ = NewDomain(1<<20, WithCache()) + } + runtime.KeepAlive(cached) // prevent cached from being GCed + }) + + b.Run("WithoutCache", func(b *testing.B) { + for b.Loop() { + _ = NewDomain(1 << 20) + } + }) +} + +// Helper functions +func getCachedDomain(key domainCacheKey) *Domain { + keyMapLock.Lock() + keyLock, exists := domainGenLocks[key] + if !exists { + keyMapLock.Unlock() + return nil + } + + // Acquire key lock while holding global lock to prevent races + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + domainMapLock.Lock() + defer domainMapLock.Unlock() + if weak, exists := domainCache[key]; exists { + return weak.Value() + } + return nil +} diff --git a/field/goldilocks/extensions/e2_test.go b/field/goldilocks/extensions/e2_test.go index c50e416a7..abb4dc21f 100644 --- a/field/goldilocks/extensions/e2_test.go +++ b/field/goldilocks/extensions/e2_test.go @@ -513,3 +513,36 @@ func TestE2Div(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } + +var modulus = fr.Modulus() + +// genFr generates an Fr element +func genFr() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + var elmt fr.Element + // SetBigInt will reduce the value modulo the field order + // genParams.Rng is a math/rand.Rand which is not a cryptographically secure + // source of randomness. However, for property based testing, it is desirable + // to have a deterministic generator. + e := bigIntPool.Get().(*big.Int) + e.Rand(genParams.Rng, modulus) + + for i, w := range e.Bits() { + elmt[i] = uint64(w) + } + bigIntPool.Put(e) + + genResult := gopter.NewGenResult(elmt, gopter.NoShrinker) + return genResult + } +} + +// genE2 generates an E2 element +func genE2() gopter.Gen { + return gopter.CombineGens( + genFr(), + genFr(), + ).Map(func(values []interface{}) E2 { + return E2{A0: values[0].(fr.Element), A1: values[1].(fr.Element)} + }) +} diff --git a/field/goldilocks/fft/domain.go b/field/goldilocks/fft/domain.go index 80e9902fe..4a35e1adc 100644 --- a/field/goldilocks/fft/domain.go +++ b/field/goldilocks/fft/domain.go @@ -13,6 +13,7 @@ import ( "math/bits" "runtime" "sync" + "weak" "github.com/consensys/gnark-crypto/field/goldilocks" @@ -61,11 +62,115 @@ func GeneratorFullMultiplicativeGroup() goldilocks.Element { return res } -// NewDomain returns a subgroup with a power of 2 cardinality -// cardinality >= m -// shift: when specified, it's the element by which the set of root of unity is shifted. +// domainCacheKey is the composite key for the cache. +// It uses a struct with comparable types as the map key. 
+type domainCacheKey struct { + m uint64 + gen goldilocks.Element +} + +var ( + domainCache = make(map[domainCacheKey]weak.Pointer[Domain]) + domainGenLocks = make(map[domainCacheKey]*sync.Mutex) // Per key mutex to avoid multiple concurrent generation of the same domain + keyMapLock sync.Mutex // Ensures exclusive access to domainGenLocks map + domainMapLock sync.Mutex // Ensures exclusive access to domainCache map +) + +// NewDomain returns a subgroup with a power of 2 cardinality >= m. +// +// Parameters: +// - m: minimum cardinality (will be rounded up to next power of 2) +// - opts: configuration options (WithShift, WithCache, WithoutPrecompute, etc.) +// +// The domain can be cached when both withCache and withPrecompute are enabled. +// Cached domains are automatically cleaned up when no longer in use. func NewDomain(m uint64, opts ...DomainOption) *Domain { opt := domainOptions(opts...) + + // Skip caching if disabled or precomputation is off + if !opt.withCache || !opt.withPrecompute { + return createDomain(m, opt) + } + + // Compute the cache key. + key := domainCacheKey{m: m} + if opt.shift != nil { + key.gen.Set(opt.shift) + } else { + key.gen = GeneratorFullMultiplicativeGroup() // Default generator + } + + // Lets ensure that only one goroutine is generating a domain for this + // specific key. We acquire it already here to ensure if there is a existing + // goroutine generating a domain for this key, we wait for it to finish and + // then we can just return the cached domain. + keyMapLock.Lock() + keyLock := domainGenLocks[key] + if keyLock == nil { + keyLock = new(sync.Mutex) + domainGenLocks[key] = keyLock + } + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + // Check cache first. But for the cache, we need to lock the cache map (we + // currently only hold the per-key lock, not global cache lock). + domainMapLock.Lock() + // we don't defer it because we want to release it while creating the + // domain. And domain creation can panic, leading to double unlock which + // hides the original panic. + if weakDomain, exists := domainCache[key]; exists { + if domain := weakDomain.Value(); domain != nil { + domainMapLock.Unlock() + return domain + } + } + // Lets release the global cache lock while we do this so that other keys + // can be added to cache. + domainMapLock.Unlock() + + // Create a new domain (expensive operation, but only blocks same key). + domain := createDomain(m, opt) + + // Store in cache with cleanup + weakDomain := weak.Make(domain) + domainMapLock.Lock() + domainCache[key] = weakDomain + domainMapLock.Unlock() + + // Add cleanup to remove from cache when domain is garbage collected + runtime.AddCleanup(domain, func(key domainCacheKey) { + // cleanup *may* be called concurrently, but could be sequential. We run + // it in a separate goroutine to avoid block other cleanups being run if + // this cleanup is being run on the same key which is being generated + // (thus lock being held). + go func() { + keyMapLock.Lock() + defer keyMapLock.Unlock() + if keyLock, ok := domainGenLocks[key]; ok { + keyLock.Lock() + defer keyLock.Unlock() + // We can now safely delete from both maps. But we only do if + // the cached weak pointer is the same one we created. Otherwise + // this means this cleanup is running after a new domain was + // already cached (double cleanup). 
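From a caller's point of view the cache is transparent: with this patch applied, repeated NewDomain calls with the same size and shift hand back the same precomputed *Domain for as long as some caller still references it. A hypothetical usage sketch (the goldilocks fft package path is the one shown in this diff):

package main

import (
	"fmt"
	"runtime"

	"github.com/consensys/gnark-crypto/field/goldilocks/fft"
)

func main() {
	d1 := fft.NewDomain(1<<20, fft.WithCache())
	d2 := fft.NewDomain(1<<20, fft.WithCache())
	fmt.Println(d1 == d2) // true: the second call is a cache hit while d1 is alive

	// ... run FFTs with the shared domain ...

	runtime.KeepAlive(d1) // keep the cache entry alive for as long as it is needed
}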
+ + // We also want to hold both per-key and cache lock to avoid the + // maps being out of sync + domainMapLock.Lock() + defer domainMapLock.Unlock() + if cacheWeakDomain := domainCache[key]; cacheWeakDomain == weakDomain { + delete(domainCache, key) + delete(domainGenLocks, key) + } + } + }() + }, key) + return domain +} + +func createDomain(m uint64, opt domainConfig) *Domain { domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) diff --git a/field/goldilocks/fft/domain_test.go b/field/goldilocks/fft/domain_test.go index 7049120e6..7761c83ce 100644 --- a/field/goldilocks/fft/domain_test.go +++ b/field/goldilocks/fft/domain_test.go @@ -8,7 +8,11 @@ package fft import ( "bytes" "reflect" + "runtime" "testing" + + "github.com/consensys/gnark-crypto/field/goldilocks" + "github.com/stretchr/testify/require" ) func TestDomainSerialization(t *testing.T) { @@ -34,3 +38,97 @@ func TestDomainSerialization(t *testing.T) { t.Fatal("Domain.SetBytes(Bytes()) failed") } } + +func TestNewDomainCache(t *testing.T) { + t.Run("CacheWithoutShift", func(t *testing.T) { + key1 := domainCacheKey{ + m: 256, + gen: GeneratorFullMultiplicativeGroup(), + } + require.Nil(t, getCachedDomain(key1), "cache should be empty initially") + domain1 := NewDomain(256, WithCache()) + expected1 := NewDomain(256) + require.Equal(t, domain1, expected1, "domain1 should equal expected1") + require.Same(t, domain1, getCachedDomain(key1), "domain1 should be stored in cache") + }) + + t.Run("CacheWithShift", func(t *testing.T) { + shift := goldilocks.NewElement(5) + key2 := domainCacheKey{ + m: 512, + gen: shift, + } + require.Nil(t, getCachedDomain(key2), "cache should be empty initially") + domain2 := NewDomain(512, WithShift(shift), WithCache()) + expected2 := NewDomain(512, WithShift(shift)) + require.Equal(t, domain2, expected2, "domain2 should equal expected2") + require.Same(t, domain2, getCachedDomain(key2), "domain2 should be stored in cache") + }) +} + +func TestGCBehavior(t *testing.T) { + t.Run("DomainKeptAliveNotCollected", func(t *testing.T) { + domain := NewDomain(1<<20, WithCache()) + require.NotNil(t, domain, "domain should not be empty") + key := domainCacheKey{ + m: 1 << 20, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + runtime.GC() + require.NotNil(t, getCachedDomain(key), "domain should still be cached") + runtime.KeepAlive(domain) + }) + + t.Run("UnreferencedDomainCollected", func(t *testing.T) { + domain := NewDomain(1<<19, WithCache()) + key := domainCacheKey{ + m: 1 << 19, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + // Last use of domain + _ = domain.Cardinality + runtime.GC() + require.Nil(t, getCachedDomain(key), "unreferenced domain should be collected and removed from cache") + }) +} + +func BenchmarkNewDomainCache(b *testing.B) { + b.Run("WithCache", func(b *testing.B) { + // lets first initialize in cache already + cached := NewDomain(1<<20, WithCache()) + for b.Loop() { + _ = NewDomain(1<<20, WithCache()) + } + runtime.KeepAlive(cached) // prevent cached from being GCed + }) + + b.Run("WithoutCache", func(b *testing.B) { + for b.Loop() { + _ = NewDomain(1 << 20) + } + }) +} + +// Helper functions +func getCachedDomain(key domainCacheKey) *Domain { + keyMapLock.Lock() + keyLock, exists := domainGenLocks[key] + if !exists { + keyMapLock.Unlock() + return nil + } + + // Acquire key lock while holding global lock to prevent races + 
keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + domainMapLock.Lock() + defer domainMapLock.Unlock() + if weak, exists := domainCache[key]; exists { + return weak.Value() + } + return nil +} diff --git a/field/goldilocks/fft/options.go b/field/goldilocks/fft/options.go index db171a372..1789d5a3c 100644 --- a/field/goldilocks/fft/options.go +++ b/field/goldilocks/fft/options.go @@ -63,6 +63,7 @@ type DomainOption func(*domainConfig) type domainConfig struct { shift *goldilocks.Element withPrecompute bool + withCache bool } // WithShift sets the FrMultiplicativeGen of the domain. @@ -81,11 +82,19 @@ func WithoutPrecompute() DomainOption { } } +// WithCache enables domain caching +func WithCache() DomainOption { + return func(opt *domainConfig) { + opt.withCache = true + } +} + // default options func domainOptions(opts ...DomainOption) domainConfig { // apply options opt := domainConfig{ withPrecompute: true, + withCache: false, } for _, option := range opts { option(&opt) diff --git a/field/koalabear/extensions/e2_test.go b/field/koalabear/extensions/e2_test.go index d838cb5bb..acfc6c853 100644 --- a/field/koalabear/extensions/e2_test.go +++ b/field/koalabear/extensions/e2_test.go @@ -513,3 +513,36 @@ func TestE2Div(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } + +var modulus = fr.Modulus() + +// genFr generates an Fr element +func genFr() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + var elmt fr.Element + // SetBigInt will reduce the value modulo the field order + // genParams.Rng is a math/rand.Rand which is not a cryptographically secure + // source of randomness. However, for property based testing, it is desirable + // to have a deterministic generator. + e := bigIntPool.Get().(*big.Int) + e.Rand(genParams.Rng, modulus) + + for i, w := range e.Bits() { + elmt[i] = uint32(w) + } + bigIntPool.Put(e) + + genResult := gopter.NewGenResult(elmt, gopter.NoShrinker) + return genResult + } +} + +// genE2 generates an E2 element +func genE2() gopter.Gen { + return gopter.CombineGens( + genFr(), + genFr(), + ).Map(func(values []interface{}) E2 { + return E2{A0: values[0].(fr.Element), A1: values[1].(fr.Element)} + }) +} diff --git a/field/koalabear/fft/domain.go b/field/koalabear/fft/domain.go index 6037f1d6c..dcaf90e8a 100644 --- a/field/koalabear/fft/domain.go +++ b/field/koalabear/fft/domain.go @@ -13,6 +13,7 @@ import ( "math/bits" "runtime" "sync" + "weak" "github.com/consensys/gnark-crypto/field/koalabear" @@ -61,11 +62,115 @@ func GeneratorFullMultiplicativeGroup() koalabear.Element { return res } -// NewDomain returns a subgroup with a power of 2 cardinality -// cardinality >= m -// shift: when specified, it's the element by which the set of root of unity is shifted. +// domainCacheKey is the composite key for the cache. +// It uses a struct with comparable types as the map key. +type domainCacheKey struct { + m uint64 + gen koalabear.Element +} + +var ( + domainCache = make(map[domainCacheKey]weak.Pointer[Domain]) + domainGenLocks = make(map[domainCacheKey]*sync.Mutex) // Per key mutex to avoid multiple concurrent generation of the same domain + keyMapLock sync.Mutex // Ensures exclusive access to domainGenLocks map + domainMapLock sync.Mutex // Ensures exclusive access to domainCache map +) + +// NewDomain returns a subgroup with a power of 2 cardinality >= m. 
+// +// Parameters: +// - m: minimum cardinality (will be rounded up to next power of 2) +// - opts: configuration options (WithShift, WithCache, WithoutPrecompute, etc.) +// +// The domain can be cached when both withCache and withPrecompute are enabled. +// Cached domains are automatically cleaned up when no longer in use. func NewDomain(m uint64, opts ...DomainOption) *Domain { opt := domainOptions(opts...) + + // Skip caching if disabled or precomputation is off + if !opt.withCache || !opt.withPrecompute { + return createDomain(m, opt) + } + + // Compute the cache key. + key := domainCacheKey{m: m} + if opt.shift != nil { + key.gen.Set(opt.shift) + } else { + key.gen = GeneratorFullMultiplicativeGroup() // Default generator + } + + // Lets ensure that only one goroutine is generating a domain for this + // specific key. We acquire it already here to ensure if there is a existing + // goroutine generating a domain for this key, we wait for it to finish and + // then we can just return the cached domain. + keyMapLock.Lock() + keyLock := domainGenLocks[key] + if keyLock == nil { + keyLock = new(sync.Mutex) + domainGenLocks[key] = keyLock + } + keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + // Check cache first. But for the cache, we need to lock the cache map (we + // currently only hold the per-key lock, not global cache lock). + domainMapLock.Lock() + // we don't defer it because we want to release it while creating the + // domain. And domain creation can panic, leading to double unlock which + // hides the original panic. + if weakDomain, exists := domainCache[key]; exists { + if domain := weakDomain.Value(); domain != nil { + domainMapLock.Unlock() + return domain + } + } + // Lets release the global cache lock while we do this so that other keys + // can be added to cache. + domainMapLock.Unlock() + + // Create a new domain (expensive operation, but only blocks same key). + domain := createDomain(m, opt) + + // Store in cache with cleanup + weakDomain := weak.Make(domain) + domainMapLock.Lock() + domainCache[key] = weakDomain + domainMapLock.Unlock() + + // Add cleanup to remove from cache when domain is garbage collected + runtime.AddCleanup(domain, func(key domainCacheKey) { + // cleanup *may* be called concurrently, but could be sequential. We run + // it in a separate goroutine to avoid block other cleanups being run if + // this cleanup is being run on the same key which is being generated + // (thus lock being held). + go func() { + keyMapLock.Lock() + defer keyMapLock.Unlock() + if keyLock, ok := domainGenLocks[key]; ok { + keyLock.Lock() + defer keyLock.Unlock() + // We can now safely delete from both maps. But we only do if + // the cached weak pointer is the same one we created. Otherwise + // this means this cleanup is running after a new domain was + // already cached (double cleanup). 
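The key built above is a plain value struct, which is why it can be used directly as a map key: field elements are fixed-size arrays and compare element-wise, so a coset domain created with WithShift never collides with the default-generator domain of the same size. A tiny illustration (cacheKey is a renamed copy of domainCacheKey, and the generator values 3 and 5 are placeholders rather than the real group generators):

package main

import (
	"fmt"

	"github.com/consensys/gnark-crypto/field/koalabear"
)

// cacheKey mirrors the shape of domainCacheKey for illustration only.
type cacheKey struct {
	m   uint64
	gen koalabear.Element
}

func main() {
	k1 := cacheKey{m: 1 << 10, gen: koalabear.NewElement(3)} // "default generator" stand-in
	k2 := cacheKey{m: 1 << 10, gen: koalabear.NewElement(5)} // a shifted (coset) domain

	fmt.Println(k1 == k2) // false: same size, different shift, distinct cache entries
}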
+ + // We also want to hold both per-key and cache lock to avoid the + // maps being out of sync + domainMapLock.Lock() + defer domainMapLock.Unlock() + if cacheWeakDomain := domainCache[key]; cacheWeakDomain == weakDomain { + delete(domainCache, key) + delete(domainGenLocks, key) + } + } + }() + }, key) + return domain +} + +func createDomain(m uint64, opt domainConfig) *Domain { domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) diff --git a/field/koalabear/fft/domain_test.go b/field/koalabear/fft/domain_test.go index 7049120e6..44fcc4a62 100644 --- a/field/koalabear/fft/domain_test.go +++ b/field/koalabear/fft/domain_test.go @@ -8,7 +8,11 @@ package fft import ( "bytes" "reflect" + "runtime" "testing" + + "github.com/consensys/gnark-crypto/field/koalabear" + "github.com/stretchr/testify/require" ) func TestDomainSerialization(t *testing.T) { @@ -34,3 +38,97 @@ func TestDomainSerialization(t *testing.T) { t.Fatal("Domain.SetBytes(Bytes()) failed") } } + +func TestNewDomainCache(t *testing.T) { + t.Run("CacheWithoutShift", func(t *testing.T) { + key1 := domainCacheKey{ + m: 256, + gen: GeneratorFullMultiplicativeGroup(), + } + require.Nil(t, getCachedDomain(key1), "cache should be empty initially") + domain1 := NewDomain(256, WithCache()) + expected1 := NewDomain(256) + require.Equal(t, domain1, expected1, "domain1 should equal expected1") + require.Same(t, domain1, getCachedDomain(key1), "domain1 should be stored in cache") + }) + + t.Run("CacheWithShift", func(t *testing.T) { + shift := koalabear.NewElement(5) + key2 := domainCacheKey{ + m: 512, + gen: shift, + } + require.Nil(t, getCachedDomain(key2), "cache should be empty initially") + domain2 := NewDomain(512, WithShift(shift), WithCache()) + expected2 := NewDomain(512, WithShift(shift)) + require.Equal(t, domain2, expected2, "domain2 should equal expected2") + require.Same(t, domain2, getCachedDomain(key2), "domain2 should be stored in cache") + }) +} + +func TestGCBehavior(t *testing.T) { + t.Run("DomainKeptAliveNotCollected", func(t *testing.T) { + domain := NewDomain(1<<20, WithCache()) + require.NotNil(t, domain, "domain should not be empty") + key := domainCacheKey{ + m: 1 << 20, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + runtime.GC() + require.NotNil(t, getCachedDomain(key), "domain should still be cached") + runtime.KeepAlive(domain) + }) + + t.Run("UnreferencedDomainCollected", func(t *testing.T) { + domain := NewDomain(1<<19, WithCache()) + key := domainCacheKey{ + m: 1 << 19, + gen: GeneratorFullMultiplicativeGroup(), + } + require.NotNil(t, getCachedDomain(key), "domain should be cached") + // Last use of domain + _ = domain.Cardinality + runtime.GC() + require.Nil(t, getCachedDomain(key), "unreferenced domain should be collected and removed from cache") + }) +} + +func BenchmarkNewDomainCache(b *testing.B) { + b.Run("WithCache", func(b *testing.B) { + // lets first initialize in cache already + cached := NewDomain(1<<20, WithCache()) + for b.Loop() { + _ = NewDomain(1<<20, WithCache()) + } + runtime.KeepAlive(cached) // prevent cached from being GCed + }) + + b.Run("WithoutCache", func(b *testing.B) { + for b.Loop() { + _ = NewDomain(1 << 20) + } + }) +} + +// Helper functions +func getCachedDomain(key domainCacheKey) *Domain { + keyMapLock.Lock() + keyLock, exists := domainGenLocks[key] + if !exists { + keyMapLock.Unlock() + return nil + } + + // Acquire key lock while holding global lock to prevent races + 
keyLock.Lock() + defer keyLock.Unlock() + keyMapLock.Unlock() + + domainMapLock.Lock() + defer domainMapLock.Unlock() + if weak, exists := domainCache[key]; exists { + return weak.Value() + } + return nil +} diff --git a/field/koalabear/fft/options.go b/field/koalabear/fft/options.go index 7c4ba3ffb..b3d8a95c1 100644 --- a/field/koalabear/fft/options.go +++ b/field/koalabear/fft/options.go @@ -63,6 +63,7 @@ type DomainOption func(*domainConfig) type domainConfig struct { shift *koalabear.Element withPrecompute bool + withCache bool } // WithShift sets the FrMultiplicativeGen of the domain. @@ -81,11 +82,19 @@ func WithoutPrecompute() DomainOption { } } +// WithCache enables domain caching +func WithCache() DomainOption { + return func(opt *domainConfig) { + opt.withCache = true + } +} + // default options func domainOptions(opts ...DomainOption) domainConfig { // apply options opt := domainConfig{ withPrecompute: true, + withCache: false, } for _, option := range opts { option(&opt) From abd34dcc60051487e9bd71180444d3a48b20aad2 Mon Sep 17 00:00:00 2001 From: feltroid Prime <96737978+feltroidprime@users.noreply.github.com> Date: Fri, 2 May 2025 13:22:55 +0200 Subject: [PATCH 4/4] fix: Eisenstein Half-GCD convergence (#680) Co-authored-by: Ivo Kubjas --- field/eisenstein/eisenstein.go | 71 ++++++++++++++++++++++++++--- field/eisenstein/eisenstein_test.go | 61 +++++++++++++++++++++++++ 2 files changed, 125 insertions(+), 7 deletions(-) diff --git a/field/eisenstein/eisenstein.go b/field/eisenstein/eisenstein.go index e80915413..033ed902b 100644 --- a/field/eisenstein/eisenstein.go +++ b/field/eisenstein/eisenstein.go @@ -9,6 +9,27 @@ type ComplexNumber struct { A0, A1 *big.Int } +// ────────────────────────────────────────────────────────────────────────────── +// helpers – hex-lattice geometry & symmetric rounding +// ────────────────────────────────────────────────────────────────────────────── + +// six axial directions of the hexagonal lattice +var neighbours = [][2]int64{ + {1, 0}, {0, 1}, {-1, 1}, {-1, 0}, {0, -1}, {1, -1}, +} + +// roundNearest returns ⌊(z + d/2) / d⌋ for *any* sign of z, d>0 +func roundNearest(z, d *big.Int) *big.Int { + half := new(big.Int).Rsh(d, 1) // d / 2 + if z.Sign() >= 0 { + return new(big.Int).Div(new(big.Int).Add(z, half), d) + } + tmp := new(big.Int).Neg(z) + tmp.Add(tmp, half) + tmp.Div(tmp, d) + return tmp.Neg(tmp) +} + func (z *ComplexNumber) init() { if z.A0 == nil { z.A0 = new(big.Int) @@ -124,19 +145,55 @@ func (z *ComplexNumber) Norm() *big.Int { return norm } -// QuoRem sets z to the quotient of x and y, r to the remainder, and returns z and r. +// QuoRem sets z to the Euclidean quotient of x / y, r to the remainder, +// and guarantees ‖r‖ < ‖y‖ (true Euclidean division in ℤ[ω]). func (z *ComplexNumber) QuoRem(x, y, r *ComplexNumber) (*ComplexNumber, *ComplexNumber) { - norm := y.Norm() - if norm.Cmp(big.NewInt(0)) == 0 { + + norm := y.Norm() // > 0 (Eisenstein norm is always non-neg) + if norm.Sign() == 0 { panic("division by zero") } - z.Conjugate(y) - z.Mul(x, z) - z.A0.Div(z.A0, norm) - z.A1.Div(z.A1, norm) + + // num = x * ȳ (ȳ computed in a fresh variable → y unchanged) + var yConj, num ComplexNumber + yConj.Conjugate(y) + num.Mul(x, &yConj) + + // first guess by *symmetric* rounding of both coordinates + q0 := roundNearest(num.A0, norm) + q1 := roundNearest(num.A1, norm) + z.A0, z.A1 = q0, q1 + + // r = x – q*y r.Mul(y, z) r.Sub(x, r) + // If Euclidean inequality already holds we're done. 
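The heart of this fix is the switch from big.Int.Div, which for the positive norm used here floors the quotient, to symmetric nearest-integer rounding: with floored quotients a negative coordinate could leave a remainder whose norm was not smaller than the divisor's, which is what the Half-GCD regression test at the end of this patch guards against. A small standalone check of roundNearest against Div (the divisor 5 and the sample inputs are arbitrary; roundNearest is copied from the patch):

package main

import (
	"fmt"
	"math/big"
)

// roundNearest mirrors the helper from the patch: nearest-integer division,
// symmetric around zero, for any sign of z and d > 0.
func roundNearest(z, d *big.Int) *big.Int {
	half := new(big.Int).Rsh(d, 1) // d / 2
	if z.Sign() >= 0 {
		return new(big.Int).Div(new(big.Int).Add(z, half), d)
	}
	tmp := new(big.Int).Neg(z)
	tmp.Add(tmp, half)
	tmp.Div(tmp, d)
	return tmp.Neg(tmp)
}

func main() {
	d := big.NewInt(5)
	for _, z := range []int64{7, 8, -7, -8} {
		zi := big.NewInt(z)
		fmt.Println(z, "nearest:", roundNearest(zi, d), "floor-style Div:", new(big.Int).Div(zi, d))
	}
	// Output: 7 -> 1 vs 1, 8 -> 2 vs 1, -7 -> -1 vs -2, -8 -> -2 vs -2;
	// the nearest quotient keeps |z - q*d| <= d/2, the floored one does not.
}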
+ // Otherwise walk ≤2 unit steps in the hex lattice until N(r) < N(y). + if r.Norm().Cmp(norm) >= 0 { + bestQ0, bestQ1 := new(big.Int).Set(z.A0), new(big.Int).Set(z.A1) + bestR := new(ComplexNumber).Set(r) + bestN2 := bestR.Norm() + + for _, dir := range neighbours { + candQ0 := new(big.Int).Add(z.A0, big.NewInt(dir[0])) + candQ1 := new(big.Int).Add(z.A1, big.NewInt(dir[1])) + var candQ ComplexNumber + candQ.A0, candQ.A1 = candQ0, candQ1 + + var candR ComplexNumber + candR.Mul(y, &candQ) + candR.Sub(x, &candR) + + if candR.Norm().Cmp(bestN2) < 0 { + bestQ0, bestQ1 = candQ0, candQ1 + bestR.Set(&candR) + bestN2 = bestR.Norm() + } + } + z.A0, z.A1 = bestQ0, bestQ1 + r.Set(bestR) // update remainder and retry; Euclidean property ⇒ ≤ 2 loops + } return z, r } diff --git a/field/eisenstein/eisenstein_test.go b/field/eisenstein/eisenstein_test.go index 6aff795f9..0d4620444 100644 --- a/field/eisenstein/eisenstein_test.go +++ b/field/eisenstein/eisenstein_test.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "math/big" "testing" + "time" "github.com/leanovate/gopter" "github.com/leanovate/gopter/prop" @@ -240,6 +241,66 @@ func TestEisensteinHalfGCD(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } +func TestEisensteinQuoRem(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + genE := GenComplexNumber(boundSize) + + properties.Property("QuoRem should be correct", prop.ForAll( + func(a, b *ComplexNumber) bool { + var z, rem ComplexNumber + z.QuoRem(a, b, &rem) + var res ComplexNumber + res.Mul(b, &z) + res.Add(&res, &rem) + return res.Equal(a) + }, + genE, + genE, + )) + + properties.Property("QuoRem remainder should be smaller than divisor", prop.ForAll( + func(a, b *ComplexNumber) bool { + var z, rem ComplexNumber + z.QuoRem(a, b, &rem) + return rem.Norm().Cmp(b.Norm()) == -1 + }, + genE, + genE, + )) +} + +func TestRegressionHalfGCD1483(t *testing.T) { + // This test is a regression test for issue #1483 in gnark + a0, _ := new(big.Int).SetString("64502973549206556628585045361533709077", 10) + a1, _ := new(big.Int).SetString("-303414439467246543595250775667605759171", 10) + c0, _ := new(big.Int).SetString("-432420386565659656852420866390673177323", 10) + c1, _ := new(big.Int).SetString("238911465918039986966665730306072050094", 10) + a := ComplexNumber{A0: a0, A1: a1} + c := ComplexNumber{A0: c0, A1: c1} + + ticker := time.NewTimer(time.Second * 3) + doneCh := make(chan struct{}) + go func() { + HalfGCD(&a, &c) + close(doneCh) + }() + + select { + case <-ticker.C: + t.Error("HalfGCD took too long to compute") + case <-doneCh: + // Test passed + } +} + // GenNumber generates a random integer func GenNumber(boundSize int64) gopter.Gen { return func(genParams *gopter.GenParameters) *gopter.GenResult {