Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 84 additions & 11 deletions dtlstransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
package webrtc

import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -47,7 +49,7 @@
onStateChangeHandler func(DTLSTransportState)
internalOnCloseHandler func()

conn *dtls.Conn
conn DTLSConn

srtpSession, srtcpSession atomic.Value
srtpEndpoint, srtcpEndpoint *mux.Endpoint
Expand All @@ -60,6 +62,50 @@
log logging.LeveledLogger
}

// DTLSConn wraps the DTLS connection used by DTLSTransport.
// It can be injected via SettingEngine to allow replacement at runtime.
type DTLSConn interface {
net.Conn
Handshake() error
HandshakeContext(ctx context.Context) error
SelectedSRTPProtectionProfile() (dtls.SRTPProtectionProfile, bool)
SRTPKeyingMaterialExporter() (srtp.KeyingMaterialExporter, error)
}

// dtlsConnFactory is a factory function for creating a DTLS connection.
type dtlsConnFactory func(role DTLSRole, conn net.PacketConn, remoteAddr net.Addr, config *dtls.Config) (DTLSConn, error)

Check failure on line 76 in dtlstransport.go

View workflow job for this annotation

GitHub Actions / lint / Go

type dtlsConnFactory is unused (unused)

// pionDTLSConn wraps a *dtls.Conn to implement the DTLSConn interface.
// The only thing we need to implement is the SRTPKeyingMaterialExporter method,
// all other methods are thunks to the underlying *dtls.Conn.
type pionDTLSConn struct {
conn *dtls.Conn
}

func (c *pionDTLSConn) Close() error { return c.conn.Close() }
func (c *pionDTLSConn) Read(b []byte) (int, error) { return c.conn.Read(b) }
func (c *pionDTLSConn) Write(b []byte) (int, error) { return c.conn.Write(b) }
func (c *pionDTLSConn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
func (c *pionDTLSConn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }
func (c *pionDTLSConn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
func (c *pionDTLSConn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
func (c *pionDTLSConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }
func (c *pionDTLSConn) Handshake() error { return c.conn.Handshake() }
func (c *pionDTLSConn) HandshakeContext(ctx context.Context) error {
return c.conn.HandshakeContext(ctx)
}

Check failure on line 96 in dtlstransport.go

View workflow job for this annotation

GitHub Actions / lint / Go

File is not properly formatted (gofumpt)
func (c *pionDTLSConn) SelectedSRTPProtectionProfile() (dtls.SRTPProtectionProfile, bool) {
return c.conn.SelectedSRTPProtectionProfile()
}
func (c *pionDTLSConn) SRTPKeyingMaterialExporter() (srtp.KeyingMaterialExporter, error) {
connState, ok := c.conn.ConnectionState()
if !ok {
return nil, fmt.Errorf("failed to get DTLS ConnectionState")

Check failure on line 103 in dtlstransport.go

View workflow job for this annotation

GitHub Actions / lint / Go

do not define dynamic errors, use wrapped static errors instead: "fmt.Errorf(\"failed to get DTLS ConnectionState\")" (err113)
}

return &connState, nil
}

type simulcastStreamPair struct {
srtp *srtp.ReadStreamSRTP
srtcp *srtp.ReadStreamSRTCP
Expand Down Expand Up @@ -226,13 +272,13 @@
)
}

connState, ok := t.conn.ConnectionState()
if !ok {
exporter, err := t.conn.SRTPKeyingMaterialExporter()
if err != nil {
// nolint
return fmt.Errorf("%w: Failed to get DTLS ConnectionState", errDtlsKeyExtractionFailed)
return fmt.Errorf("%w: %v", errDtlsKeyExtractionFailed, err)
}

err := srtpConfig.ExtractSessionKeysFromDTLS(&connState, t.role() == DTLSRoleClient)
err = srtpConfig.ExtractSessionKeysFromDTLS(exporter, t.role() == DTLSRoleClient)
if err != nil {
// nolint
return fmt.Errorf("%w: %v", errDtlsKeyExtractionFailed, err)
Expand Down Expand Up @@ -300,6 +346,38 @@
return defaultDtlsRoleAnswer
}

func (t *DTLSTransport) createDTLSConn(
role DTLSRole,
conn net.PacketConn,
remoteAddr net.Addr,
config *dtls.Config,
) (DTLSConn, error) {
factory := t.api.settingEngine.dtls.connFactory
if factory == nil {
factory = defaultDTLSConnFactory
}
return factory(role, conn, remoteAddr, config)

Check failure on line 359 in dtlstransport.go

View workflow job for this annotation

GitHub Actions / lint / Go

return with no blank line before (nlreturn)
}

func defaultDTLSConnFactory(
role DTLSRole,
conn net.PacketConn,
remoteAddr net.Addr,
config *dtls.Config,
) (DTLSConn, error) {
var dtlsConn *dtls.Conn
var err error
if role == DTLSRoleClient {
dtlsConn, err = dtls.Client(conn, remoteAddr, config)
} else {
dtlsConn, err = dtls.Server(conn, remoteAddr, config)
}
if err != nil {
return nil, err
}
return &pionDTLSConn{conn: dtlsConn}, nil

Check failure on line 378 in dtlstransport.go

View workflow job for this annotation

GitHub Actions / lint / Go

return with no blank line before (nlreturn)
}

// Start DTLS transport negotiation with the parameters of the remote DTLS transport.
func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error { //nolint:gocognit,cyclop
// Take lock and prepare connection, we must not hold the lock
Expand Down Expand Up @@ -345,7 +423,6 @@
}, nil
}

var dtlsConn *dtls.Conn
dtlsEndpoint := t.iceTransport.newEndpoint(mux.MatchDTLS)
dtlsEndpoint.SetOnClose(t.internalOnCloseHandler)
role, dtlsConfig, err := prepareTransport()
Expand Down Expand Up @@ -394,11 +471,7 @@

// Connect as DTLS Client/Server, function is blocking and we
// must not hold the DTLSTransport lock
if role == DTLSRoleClient {
dtlsConn, err = dtls.Client(dtlsEndpoint, dtlsEndpoint.RemoteAddr(), dtlsConfig)
} else {
dtlsConn, err = dtls.Server(dtlsEndpoint, dtlsEndpoint.RemoteAddr(), dtlsConfig)
}
dtlsConn, err := t.createDTLSConn(role, dtlsEndpoint, dtlsEndpoint.RemoteAddr(), dtlsConfig)

if err == nil {
if t.api.settingEngine.dtls.connectContextMaker != nil {
Expand Down
9 changes: 9 additions & 0 deletions settingengine.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
retransmissionInterval time.Duration
ellipticCurves []dtlsElliptic.Curve
connectContextMaker func() (context.Context, func())
connFactory func(role DTLSRole, conn net.PacketConn, remoteAddr net.Addr, config *dtls.Config) (DTLSConn, error)

Check failure on line 75 in settingengine.go

View workflow job for this annotation

GitHub Actions / lint / Go

The line is 132 characters long, which exceeds the maximum of 120 characters. (lll)
extendedMasterSecret dtls.ExtendedMasterSecretType
clientAuth *dtls.ClientAuthType
clientCAs *x509.CertPool
Expand Down Expand Up @@ -551,6 +552,14 @@
e.dtls.connectContextMaker = connectContextMaker
}

// SetDTLSConnFactory overrides the DTLS connection creation for DTLSTransport.
// If nil, the default pion/dtls Client/Server constructors are used.
func (e *SettingEngine) SetDTLSConnFactory(
factory func(role DTLSRole, conn net.PacketConn, remoteAddr net.Addr, config *dtls.Config) (DTLSConn, error),
) {
e.dtls.connFactory = factory
}

// SetDTLSExtendedMasterSecret sets the extended master secret type for DTLS.
func (e *SettingEngine) SetDTLSExtendedMasterSecret(extendedMasterSecret dtls.ExtendedMasterSecretType) {
e.dtls.extendedMasterSecret = extendedMasterSecret
Expand Down
Loading