Skip to content

Commit edc7ec9

Browse files
author
Moses Narrow
committed
Merge remote-tracking branch 'upstream/develop' into develop
2 parents d9cd661 + 7965982 commit edc7ec9

File tree

5 files changed

+133
-4
lines changed

5 files changed

+133
-4
lines changed

pkg/dmsg/server.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,10 +207,22 @@ func (s *Server) Ready() <-chan struct{} {
207207
}
208208

209209
func (s *Server) handleSession(conn net.Conn) {
210+
defer func() {
211+
if r := recover(); r != nil {
212+
s.log.WithField("panic", r).
213+
WithField("remote_tcp", conn.RemoteAddr()).
214+
Error("Recovered from panic in handleSession, connection will be closed")
215+
if err := conn.Close(); err != nil {
216+
s.log.WithError(err).Warn("Failed to close connection after panic recovery")
217+
}
218+
}
219+
}()
220+
210221
log := s.log.WithField("remote_tcp", conn.RemoteAddr())
211222

212223
dSes, err := makeServerSession(s.m, &s.EntityCommon, conn)
213224
if err != nil {
225+
log.WithError(err).Warn("Failed to create server session")
214226
if err := conn.Close(); err != nil {
215227
log.WithError(err).Warn("On handleSession() failure, close connection resulted in error.")
216228
}

pkg/dmsg/server_session.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ func (ss *ServerSession) Serve() {
6262
log.Info("Initiating stream.")
6363

6464
go func(sStr *smux.Stream) {
65+
defer func() {
66+
if r := recover(); r != nil {
67+
log.WithField("panic", r).Error("Recovered from panic in serveStream")
68+
}
69+
}()
6570
err := ss.serveStream(log, sStr, ss.sm.addr)
6671
log.WithError(err).Info("Stopped stream.")
6772
}(sStr)
@@ -83,6 +88,11 @@ func (ss *ServerSession) Serve() {
8388
log.Info("Initiating stream.")
8489

8590
go func(yStr *yamux.Stream) {
91+
defer func() {
92+
if r := recover(); r != nil {
93+
log.WithField("panic", r).Error("Recovered from panic in serveStream")
94+
}
95+
}()
8696
err := ss.serveStream(log, yStr, ss.sm.addr)
8797
log.WithError(err).Info("Stopped stream.")
8898
}(yStr)

pkg/dmsg/stream_test.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,3 +307,90 @@ func GenKeyPair(t *testing.T, seed string) (cipher.PubKey, cipher.SecKey) {
307307
require.NoError(t, err)
308308
return pk, sk
309309
}
310+
311+
// TestInvalidPublicKeyNoPanic tests that the server doesn't crash when receiving
312+
// a connection with an invalid public key during the noise handshake.
313+
// This is a regression test for a 2+ year old bug where invalid public keys
314+
// would cause the server to panic and crash.
315+
func TestInvalidPublicKeyNoPanic(t *testing.T) {
316+
// Prepare mock discovery.
317+
dc := disc.NewMock(0)
318+
const maxSessions = 10
319+
320+
// Prepare dmsg server.
321+
pkSrv, skSrv := GenKeyPair(t, "server")
322+
srvConf := &ServerConfig{
323+
MaxSessions: maxSessions,
324+
UpdateInterval: 0,
325+
}
326+
srv := NewServer(pkSrv, skSrv, dc, srvConf, nil)
327+
srv.SetLogger(logging.MustGetLogger("server"))
328+
lisSrv, err := net.Listen("tcp", "")
329+
require.NoError(t, err)
330+
331+
// Serve dmsg server.
332+
chSrv := make(chan error, 1)
333+
go func() { chSrv <- srv.Serve(lisSrv, "") }() //nolint:errcheck
334+
335+
// Give server time to start
336+
time.Sleep(500 * time.Millisecond)
337+
338+
// Attempt to send a handshake with invalid public key data
339+
// This simulates a malicious or buggy client
340+
t.Run("invalid_pubkey_handshake", func(t *testing.T) {
341+
conn, err := net.Dial("tcp", lisSrv.Addr().String())
342+
require.NoError(t, err)
343+
defer func() { _ = conn.Close() }() //nolint:errcheck
344+
345+
// Send invalid noise handshake data (contains invalid public key)
346+
// In a real noise handshake, the public key would be embedded in the message
347+
// We send malformed data that will trigger invalid public key error
348+
invalidData := make([]byte, 100)
349+
// Write some invalid data that looks like a handshake but has invalid key
350+
copy(invalidData, []byte{0x00, 0x32}) // frame length prefix (50 bytes)
351+
// Rest is invalid/random data that will fail public key validation
352+
for i := 2; i < len(invalidData); i++ {
353+
invalidData[i] = byte(i) // deterministic but invalid
354+
}
355+
356+
_, err = conn.Write(invalidData)
357+
// Write may succeed, but the server should handle the invalid data gracefully
358+
if err != nil {
359+
t.Logf("Write failed (expected): %v", err)
360+
}
361+
362+
// Give server time to process the invalid handshake
363+
time.Sleep(500 * time.Millisecond)
364+
365+
// Read to see if connection was closed (expected behavior)
366+
buf := make([]byte, 10)
367+
_ = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) //nolint:errcheck
368+
_, _ = conn.Read(buf) //nolint:errcheck
369+
// We expect the connection to be closed or timeout
370+
// The important thing is the server didn't crash
371+
})
372+
373+
// Verify server is still running and can accept valid connections
374+
t.Run("valid_connection_after_invalid", func(t *testing.T) {
375+
// Prepare and serve a valid dmsg client
376+
pkA, skA := GenKeyPair(t, "client A")
377+
clientA := NewClient(pkA, skA, dc, DefaultConfig())
378+
clientA.SetLogger(logging.MustGetLogger("client_A"))
379+
go clientA.Serve(context.Background())
380+
381+
// Wait for client to register
382+
time.Sleep(time.Second * 2)
383+
384+
// Attempt to use the client - if server crashed, this will fail
385+
lis, err := clientA.Listen(8081)
386+
require.NoError(t, err, "Server should still be running and accept valid connections")
387+
388+
// Clean up
389+
require.NoError(t, lis.Close())
390+
require.NoError(t, clientA.Close())
391+
})
392+
393+
// Closing logic - server should still be healthy
394+
require.NoError(t, srv.Close())
395+
require.NoError(t, <-chSrv)
396+
}

pkg/noise/dh.go

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,24 @@ func (Secp256k1) GenerateKeypair(_ io.Reader) (noise.DHKey, error) {
2222

2323
// DH helps to implement `noise.DHFunc`.
2424
func (Secp256k1) DH(sk, pk []byte) []byte {
25-
return append(
26-
cipher.MustECDH(cipher.MustNewPubKey(pk), cipher.MustNewSecKey(sk)),
27-
byte(0))
25+
// Use non-panic versions to handle invalid keys gracefully
26+
pubKey, err := cipher.NewPubKey(pk)
27+
if err != nil {
28+
// Return empty key on error to prevent panic
29+
// The handshake will fail with this invalid key
30+
return make([]byte, 33)
31+
}
32+
secKey, err := cipher.NewSecKey(sk)
33+
if err != nil {
34+
// Return empty key on error to prevent panic
35+
return make([]byte, 33)
36+
}
37+
ecdh, err := cipher.ECDH(pubKey, secKey)
38+
if err != nil {
39+
// Return empty key on error to prevent panic
40+
return make([]byte, 33)
41+
}
42+
return append(ecdh, byte(0))
2843
}
2944

3045
// DHLen helps to implement `noise.DHFunc`.

pkg/noise/read_writer.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,12 +176,17 @@ func (rw *ReadWriter) Write(p []byte) (n int, err error) {
176176
func (rw *ReadWriter) Handshake(hsTimeout time.Duration) error {
177177
errCh := make(chan error, 1)
178178
go func() {
179+
defer func() {
180+
if r := recover(); r != nil {
181+
errCh <- fmt.Errorf("handshake panic: %v", r)
182+
}
183+
close(errCh)
184+
}()
179185
if rw.ns.init {
180186
errCh <- InitiatorHandshake(rw.ns, rw.rawInput, rw.origin)
181187
} else {
182188
errCh <- ResponderHandshake(rw.ns, rw.rawInput, rw.origin)
183189
}
184-
close(errCh)
185190
}()
186191
select {
187192
case err := <-errCh:

0 commit comments

Comments
 (0)