diff --git a/CHANGELOG.md b/CHANGELOG.md index 2178ad8749..689ba2a862 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -76,6 +76,7 @@ sequentially through each stable release, selecting the latest patch version ava - API clients should use the `Tags` field instead of `ValidTags` - The `headscale nodes list` CLI command now always shows a Tags column and the `--tags` flag has been removed - **PreAuthKey CLI**: Commands now use ID-based operations instead of user+key combinations [#2992](https://github.com/juanfont/headscale/pull/2992) + - `headscale preauthkeys create` no longer requires `--user` flag (optional for tracking creation) - `headscale preauthkeys list` lists all keys (no longer filtered by user) - `headscale preauthkeys expire --id ` replaces `--user ` @@ -120,6 +121,7 @@ sequentially through each stable release, selecting the latest patch version ava - When `false`, unverified emails are allowed for OIDC authentication and the email address is stored in the user profile regardless of its verification state. - **SSH Policy**: Wildcard (`*`) is no longer supported as an SSH destination [#3009](https://github.com/juanfont/headscale/issues/3009) + - Use `autogroup:member` for user-owned devices - Use `autogroup:tagged` for tagged devices - Use specific tags (e.g., `tag:server`) for targeted access @@ -139,6 +141,7 @@ sequentially through each stable release, selecting the latest patch version ava - **SSH Policy**: SSH source/destination validation now enforces Tailscale's security model [#3010](https://github.com/juanfont/headscale/issues/3010) Per [Tailscale SSH documentation](https://tailscale.com/kb/1193/tailscale-ssh), the following rules are now enforced: + 1. **Tags cannot SSH to user-owned devices**: SSH rules with `tag:*` or `autogroup:tagged` as source cannot have username destinations (e.g., `alice@`) or `autogroup:member`/`autogroup:self` as destination 2. **Username destinations require same-user source**: If destination is a specific username (e.g., `alice@`), the source must be that exact same user only. Use `autogroup:self` for same-user SSH access instead @@ -186,6 +189,7 @@ sequentially through each stable release, selecting the latest patch version ava - Add `taildrop.enabled` configuration option to enable/disable Taildrop file sharing [#2955](https://github.com/juanfont/headscale/pull/2955) - Allow disabling the metrics server by setting empty `metrics_listen_addr` [#2914](https://github.com/juanfont/headscale/pull/2914) - Log ACME/autocert errors for easier debugging [#2933](https://github.com/juanfont/headscale/pull/2933) +- Certificates now reload on SIGHUP signal [#3041](https://github.com/juanfont/headscale/pull/3041) - Improve CLI list output formatting [#2951](https://github.com/juanfont/headscale/pull/2951) - Use Debian 13 distroless base images for containers [#2944](https://github.com/juanfont/headscale/pull/2944) - Fix ACL policy not applied to new OIDC nodes until client restart [#2890](https://github.com/juanfont/headscale/pull/2890) diff --git a/hscontrol/app.go b/hscontrol/app.go index abd29a45a5..beda208385 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -102,6 +102,11 @@ type Headscale struct { mapBatcher mapper.Batcher clientStreamsOpen sync.WaitGroup + + // TLS certificate for manual TLS configuration (non-ACME). + // Protected by tlsCertMu for concurrent access during SIGHUP reload. + tlsCertMu sync.RWMutex + tlsCert *tls.Certificate } var ( @@ -823,19 +828,28 @@ func (h *Headscale) Serve() error { case syscall.SIGHUP: log.Info(). Str("signal", sig.String()). - Msg("Received SIGHUP, reloading ACL policy") + Msg("Received SIGHUP, reloading TLS certificate") - if h.cfg.Policy.IsEmpty() { - continue + // Reload TLS certificate if using manual TLS (not ACME/Let's Encrypt) + if h.cfg.TLS.CertPath != "" && h.cfg.TLS.LetsEncrypt.Hostname == "" { + if err := h.reloadTLSCertificate(); err != nil { + log.Error().Err(err).Msg("reloading TLS certificate") + } } - changes, err := h.state.ReloadPolicy() - if err != nil { - log.Error().Err(err).Msgf("reloading policy") - continue - } + log.Info(). + Str("signal", sig.String()). + Msg("Received SIGHUP, reloading ACL policy") - h.Change(changes...) + // Reload ACL policy + if !h.cfg.Policy.IsEmpty() { + changes, err := h.state.ReloadPolicy() + if err != nil { + log.Error().Err(err).Msg("reloading ACL policy") + } else { + h.Change(changes...) + } + } default: info := func(msg string) { log.Info().Msg(msg) } @@ -995,15 +1009,47 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) { } tlsConfig := &tls.Config{ - NextProtos: []string{"http/1.1"}, - Certificates: make([]tls.Certificate, 1), - MinVersion: tls.VersionTLS12, + NextProtos: []string{"http/1.1"}, + MinVersion: tls.VersionTLS12, } - tlsConfig.Certificates[0], err = tls.LoadX509KeyPair(h.cfg.TLS.CertPath, h.cfg.TLS.KeyPath) + if err := h.reloadTLSCertificate(); err != nil { + return nil, err + } - return tlsConfig, err + tlsConfig.GetCertificate = h.getTLSCertificate + + return tlsConfig, nil + } +} + +// reloadTLSCertificate loads or reloads the TLS certificate from disk. +// This is called on startup and on SIGHUP for certificate rotation. +func (h *Headscale) reloadTLSCertificate() error { + cert, err := tls.LoadX509KeyPair(h.cfg.TLS.CertPath, h.cfg.TLS.KeyPath) + if err != nil { + return fmt.Errorf("loading TLS certificate: %w", err) } + + h.tlsCertMu.Lock() + h.tlsCert = &cert + h.tlsCertMu.Unlock() + + log.Info(). + Str("cert_path", h.cfg.TLS.CertPath). + Str("key_path", h.cfg.TLS.KeyPath). + Msg("TLS certificate loaded") + + return nil +} + +// getTLSCertificate returns the current TLS certificate. +// It implements the tls.Config.GetCertificate callback signature. +func (h *Headscale) getTLSCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { + h.tlsCertMu.RLock() + defer h.tlsCertMu.RUnlock() + + return h.tlsCert, nil } func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) { diff --git a/hscontrol/tls_test.go b/hscontrol/tls_test.go new file mode 100644 index 0000000000..0cda077898 --- /dev/null +++ b/hscontrol/tls_test.go @@ -0,0 +1,317 @@ +package hscontrol + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// createTestCertificate generates a self-signed certificate and private key for testing. +// Returns cert PEM bytes, key PEM bytes, and any error. +func createTestCertificate(hostname string) ([]byte, []byte, error) { + // Generate a private key + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, nil, err + } + + // Create certificate template + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: hostname, + Organization: []string{"Headscale Test"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(1 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{hostname}, + } + + // Self-sign the certificate + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &privKey.PublicKey, privKey) + if err != nil { + return nil, nil, err + } + + // PEM encode the certificate + certPEM := new(bytes.Buffer) + err = pem.Encode(certPEM, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + if err != nil { + return nil, nil, err + } + + // PEM encode the private key + keyPEM := new(bytes.Buffer) + err = pem.Encode(keyPEM, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privKey)}) + if err != nil { + return nil, nil, err + } + + return certPEM.Bytes(), keyPEM.Bytes(), nil +} + +// writeCertFiles writes certificate and key PEM data to files in the given directory. +func writeCertFiles(t *testing.T, dir string, certPEM, keyPEM []byte) (certPath, keyPath string) { + t.Helper() + + certPath = filepath.Join(dir, "cert.pem") + keyPath = filepath.Join(dir, "key.pem") + + err := os.WriteFile(certPath, certPEM, 0o600) + require.NoError(t, err) + + err = os.WriteFile(keyPath, keyPEM, 0o600) + require.NoError(t, err) + + return certPath, keyPath +} + +func TestReloadTLSCertificate_InitialLoad(t *testing.T) { + tmpDir := t.TempDir() + + // Create test certificate + certPEM, keyPEM, err := createTestCertificate("test.example.com") + require.NoError(t, err) + + certPath, keyPath := writeCertFiles(t, tmpDir, certPEM, keyPEM) + + // Create minimal Headscale instance with TLS config + h := &Headscale{ + cfg: &types.Config{ + TLS: types.TLSConfig{ + CertPath: certPath, + KeyPath: keyPath, + }, + }, + } + + // Test initial certificate load + err = h.reloadTLSCertificate() + require.NoError(t, err) + + // Verify certificate was loaded + cert, err := h.getTLSCertificate(nil) + require.NoError(t, err) + require.NotNil(t, cert) + + // Verify certificate content matches what we wrote + expectedCert, err := tls.LoadX509KeyPair(certPath, keyPath) + require.NoError(t, err) + assert.Equal(t, expectedCert.Certificate, cert.Certificate) +} + +func TestReloadTLSCertificate_ReloadUpdatedCert(t *testing.T) { + tmpDir := t.TempDir() + + // Create initial certificate + certPEM1, keyPEM1, err := createTestCertificate("initial.example.com") + require.NoError(t, err) + + certPath, keyPath := writeCertFiles(t, tmpDir, certPEM1, keyPEM1) + + h := &Headscale{ + cfg: &types.Config{ + TLS: types.TLSConfig{ + CertPath: certPath, + KeyPath: keyPath, + }, + }, + } + + // Load initial certificate + err = h.reloadTLSCertificate() + require.NoError(t, err) + + // Get initial certificate + initialCert, err := h.getTLSCertificate(nil) + require.NoError(t, err) + require.NotNil(t, initialCert) + + // Create and write a NEW certificate (simulating cert renewal) + certPEM2, keyPEM2, err := createTestCertificate("renewed.example.com") + require.NoError(t, err) + + err = os.WriteFile(certPath, certPEM2, 0o600) + require.NoError(t, err) + err = os.WriteFile(keyPath, keyPEM2, 0o600) + require.NoError(t, err) + + // Reload the certificate (simulates SIGHUP handler) + err = h.reloadTLSCertificate() + require.NoError(t, err) + + // Get reloaded certificate + reloadedCert, err := h.getTLSCertificate(nil) + require.NoError(t, err) + require.NotNil(t, reloadedCert) + + // Verify certificates are different (reload worked) + assert.NotEqual(t, initialCert.Certificate, reloadedCert.Certificate, + "reloaded certificate should be different from initial certificate") + + // Verify reloaded cert matches the new file + expectedCert, err := tls.LoadX509KeyPair(certPath, keyPath) + require.NoError(t, err) + assert.Equal(t, expectedCert.Certificate, reloadedCert.Certificate) +} + +func TestReloadTLSCertificate_InvalidPath(t *testing.T) { + h := &Headscale{ + cfg: &types.Config{ + TLS: types.TLSConfig{ + CertPath: "/nonexistent/path/cert.pem", + KeyPath: "/nonexistent/path/key.pem", + }, + }, + } + + err := h.reloadTLSCertificate() + require.Error(t, err) + assert.Contains(t, err.Error(), "loading TLS certificate") +} + +func TestReloadTLSCertificate_InvalidCertContent(t *testing.T) { + tmpDir := t.TempDir() + + certPath := filepath.Join(tmpDir, "cert.pem") + keyPath := filepath.Join(tmpDir, "key.pem") + + // Write invalid certificate content + err := os.WriteFile(certPath, []byte("not a valid certificate"), 0o600) + require.NoError(t, err) + err = os.WriteFile(keyPath, []byte("not a valid key"), 0o600) + require.NoError(t, err) + + h := &Headscale{ + cfg: &types.Config{ + TLS: types.TLSConfig{ + CertPath: certPath, + KeyPath: keyPath, + }, + }, + } + + err = h.reloadTLSCertificate() + require.Error(t, err) + assert.Contains(t, err.Error(), "loading TLS certificate") +} + +func TestReloadTLSCertificate_MismatchedCertAndKey(t *testing.T) { + tmpDir := t.TempDir() + + // Create two different certificates + certPEM1, _, err := createTestCertificate("cert1.example.com") + require.NoError(t, err) + + _, keyPEM2, err := createTestCertificate("cert2.example.com") + require.NoError(t, err) + + // Write cert from first pair and key from second pair (mismatched) + certPath := filepath.Join(tmpDir, "cert.pem") + keyPath := filepath.Join(tmpDir, "key.pem") + + err = os.WriteFile(certPath, certPEM1, 0o600) + require.NoError(t, err) + err = os.WriteFile(keyPath, keyPEM2, 0o600) + require.NoError(t, err) + + h := &Headscale{ + cfg: &types.Config{ + TLS: types.TLSConfig{ + CertPath: certPath, + KeyPath: keyPath, + }, + }, + } + + err = h.reloadTLSCertificate() + require.Error(t, err) + assert.Contains(t, err.Error(), "loading TLS certificate") +} + +func TestGetTLSCertificate_BeforeLoad(t *testing.T) { + h := &Headscale{ + cfg: &types.Config{}, + } + + // Before any certificate is loaded, getTLSCertificate should return nil + cert, err := h.getTLSCertificate(nil) + require.NoError(t, err) + assert.Nil(t, cert) +} + +func TestReloadTLSCertificate_ConcurrentAccess(t *testing.T) { + tmpDir := t.TempDir() + + certPEM, keyPEM, err := createTestCertificate("concurrent.example.com") + require.NoError(t, err) + + certPath, keyPath := writeCertFiles(t, tmpDir, certPEM, keyPEM) + + h := &Headscale{ + cfg: &types.Config{ + TLS: types.TLSConfig{ + CertPath: certPath, + KeyPath: keyPath, + }, + }, + } + + // Initial load + err = h.reloadTLSCertificate() + require.NoError(t, err) + + // Run concurrent readers and writers + var wg sync.WaitGroup + const numReaders = 100 + const numReloads = 10 + + // Start readers + for range numReaders { + wg.Add(1) + go func() { + defer wg.Done() + for range 100 { + cert, err := h.getTLSCertificate(nil) + assert.NoError(t, err) + assert.NotNil(t, cert) + } + }() + } + + // Start writers (reloaders) + for range numReloads { + wg.Add(1) + go func() { + defer wg.Done() + for range 10 { + err := h.reloadTLSCertificate() + assert.NoError(t, err) + } + }() + } + + wg.Wait() + + // Final verification that certificate is still accessible + cert, err := h.getTLSCertificate(nil) + require.NoError(t, err) + require.NotNil(t, cert) +} diff --git a/integration/control.go b/integration/control.go index f390d08024..59ed578131 100644 --- a/integration/control.go +++ b/integration/control.go @@ -46,4 +46,5 @@ type ControlServer interface { DebugBatcher() (*hscontrol.DebugBatcherInfo, error) DebugNodeStore() (map[types.NodeID]types.Node, error) DebugFilter() ([]tailcfg.FilterRule, error) + Reload() error } diff --git a/integration/tls_test.go b/integration/tls_test.go new file mode 100644 index 0000000000..0405e9f486 --- /dev/null +++ b/integration/tls_test.go @@ -0,0 +1,228 @@ +package integration + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "net/http" + "testing" + "time" + + "github.com/juanfont/headscale/integration/hsic" + "github.com/juanfont/headscale/integration/integrationutil" + "github.com/juanfont/headscale/integration/tsic" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + tlsCertPath = "/etc/headscale/tls.cert" + tlsKeyPath = "/etc/headscale/tls.key" +) + +// getTLSCertificate connects to the given HTTPS endpoint and returns +// the server's TLS certificate. +func getTLSCertificate(endpoint string) (*x509.Certificate, error) { + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, //nolint:gosec + }, + }, + Timeout: 5 * time.Second, + } + + resp, err := client.Get(endpoint + "/health") + if err != nil { + return nil, fmt.Errorf("connecting to endpoint: %w", err) + } + defer resp.Body.Close() + + if resp.TLS == nil || len(resp.TLS.PeerCertificates) == 0 { + return nil, fmt.Errorf("no TLS certificates received") + } + + return resp.TLS.PeerCertificates[0], nil +} + +// TestTLSCertificateReloadOnSIGHUP tests that headscale reloads TLS certificates +// when it receives a SIGHUP signal. +func TestTLSCertificateReloadOnSIGHUP(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1"}, + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + // Create headscale with TLS enabled + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithTestName("tls-reload"), + hsic.WithTLS(), + hsic.WithEmbeddedDERPServerOnly(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + // Wait for headscale to be ready and get the endpoint + endpoint := headscale.GetEndpoint() + require.Contains(t, endpoint, "https://", "endpoint should be HTTPS when TLS is enabled") + + // Get the initial certificate + var initialCert *x509.Certificate + assert.EventuallyWithT(t, func(c *assert.CollectT) { + cert, err := getTLSCertificate(endpoint) + assert.NoError(c, err) + assert.NotNil(c, cert) + initialCert = cert + }, 10*time.Second, 500*time.Millisecond, "should be able to get initial TLS certificate") + + t.Logf("Initial certificate NotBefore: %s", initialCert.NotBefore.Format(time.RFC3339Nano)) + + // Wait a bit to ensure the new certificate will have a different NotBefore time + time.Sleep(1 * time.Second) + + // Generate a new certificate (will have a different NotBefore time) + newCert, newKey, err := integrationutil.CreateCertificate(headscale.GetHostname()) + require.NoError(t, err) + + // Write the new certificate files to the container + err = headscale.WriteFile(tlsCertPath, newCert) + require.NoError(t, err, "failed to write new certificate") + + err = headscale.WriteFile(tlsKeyPath, newKey) + require.NoError(t, err, "failed to write new key") + + t.Log("New certificate written to container, sending SIGHUP...") + + // Send SIGHUP to trigger certificate reload + err = headscale.Reload() + require.NoError(t, err, "failed to send SIGHUP") + + // Wait a moment for the reload to take effect + time.Sleep(500 * time.Millisecond) + + // Verify the new certificate is being served by checking NotBefore time changed + var newCertFromServer *x509.Certificate + assert.EventuallyWithT(t, func(c *assert.CollectT) { + cert, err := getTLSCertificate(endpoint) + assert.NoError(c, err) + assert.NotNil(c, cert) + newCertFromServer = cert + + // The NotBefore time should be different (later) than the initial one + assert.True(c, cert.NotBefore.After(initialCert.NotBefore), + "new certificate NotBefore (%s) should be after initial (%s)", + cert.NotBefore.Format(time.RFC3339Nano), + initialCert.NotBefore.Format(time.RFC3339Nano)) + }, 10*time.Second, 500*time.Millisecond, "certificate should be reloaded after SIGHUP") + + t.Logf("New certificate NotBefore: %s", newCertFromServer.NotBefore.Format(time.RFC3339Nano)) + t.Log("TLS certificate reload verified successfully") +} + +// TestTLSCertificateReloadClientConnectivity tests that clients remain +// connected and functional after a TLS certificate reload. +func TestTLSCertificateReloadClientConnectivity(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + NodesPerUser: 2, + Users: []string{"user1"}, + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + // Create headscale with TLS enabled + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithTestName("tls-reload-conn"), + hsic.WithTLS(), + hsic.WithEmbeddedDERPServerOnly(), + ) + requireNoErrHeadscaleEnv(t, err) + + // Wait for clients to sync + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + allClients, err := scenario.ListTailscaleClients() + require.NoError(t, err) + require.Len(t, allClients, 2, "should have 2 clients") + + // Verify clients can ping each other before certificate reload + allIPs, err := scenario.ListTailscaleClientsIPs() + require.NoError(t, err) + + t.Log("Verifying initial connectivity...") + for _, client := range allClients { + for _, ip := range allIPs { + err := client.Ping(ip.String()) + require.NoError(t, err, "initial ping failed") + } + } + + // Get endpoint and initial certificate + endpoint := headscale.GetEndpoint() + initialCert, err := getTLSCertificate(endpoint) + require.NoError(t, err) + + t.Logf("Initial certificate NotBefore: %s", initialCert.NotBefore.Format(time.RFC3339Nano)) + + // Wait to ensure new certificate will have different NotBefore + time.Sleep(1 * time.Second) + + // Generate and write new certificate + newCert, newKey, err := integrationutil.CreateCertificate(headscale.GetHostname()) + require.NoError(t, err) + + err = headscale.WriteFile(tlsCertPath, newCert) + require.NoError(t, err) + + err = headscale.WriteFile(tlsKeyPath, newKey) + require.NoError(t, err) + + t.Log("Sending SIGHUP to reload certificate...") + err = headscale.Reload() + require.NoError(t, err) + + // Wait for reload to take effect + time.Sleep(1 * time.Second) + + // Verify certificate changed + var newCertFromServer *x509.Certificate + assert.EventuallyWithT(t, func(c *assert.CollectT) { + cert, err := getTLSCertificate(endpoint) + assert.NoError(c, err) + newCertFromServer = cert + assert.True(c, cert.NotBefore.After(initialCert.NotBefore), + "certificate should have changed") + }, 10*time.Second, 500*time.Millisecond, "certificate should be reloaded") + + t.Logf("New certificate NotBefore: %s", newCertFromServer.NotBefore.Format(time.RFC3339Nano)) + + // Verify clients can still ping each other after certificate reload + t.Log("Verifying connectivity after certificate reload...") + assert.EventuallyWithT(t, func(c *assert.CollectT) { + for _, client := range allClients { + for _, ip := range allIPs { + err := client.Ping(ip.String()) + assert.NoError(c, err, "ping after certificate reload failed") + } + } + }, 30*time.Second, 1*time.Second, "clients should remain connected after certificate reload") + + t.Log("Client connectivity verified after TLS certificate reload") +}