feat: add support for certificate reloading upon SIGHUP (#710)

Harsh Mantri and Ayman Bagabas created

* feat: add support for certificate reloading upon SIGHUP

* fix: support certificate reloading for unix and add test

* fix(cmd): move cert reloader logic to the serve package

---------

Co-authored-by: Ayman Bagabas <ayman.bagabas@gmail.com>

Change summary

cmd/soft/serve/certreloader.go      |  54 ++++++++++++++
cmd/soft/serve/certreloader_test.go | 116 +++++++++++++++++++++++++++++++
cmd/soft/serve/serve.go             |  23 ++++-
cmd/soft/serve/server.go            |  21 +++++
pkg/web/http.go                     |  10 ++
5 files changed, 216 insertions(+), 8 deletions(-)

Detailed changes

cmd/soft/serve/certreloader.go 🔗

@@ -0,0 +1,54 @@
+package serve
+
+import (
+	"crypto/tls"
+	"sync"
+
+	"charm.land/log/v2"
+)
+
+// CertReloader is responsible for reloading TLS certificates when a SIGHUP signal is received.
+type CertReloader struct {
+	certMu   sync.RWMutex
+	cert     *tls.Certificate
+	certPath string
+	keyPath  string
+}
+
+// NewCertReloader creates a new CertReloader that watches for SIGHUP signals.
+func NewCertReloader(certPath, keyPath string, logger *log.Logger) (*CertReloader, error) {
+	reloader := &CertReloader{
+		certPath: certPath,
+		keyPath:  keyPath,
+	}
+
+	cert, err := tls.LoadX509KeyPair(certPath, keyPath)
+	if err != nil {
+		return nil, err
+	}
+	reloader.cert = &cert
+
+	return reloader, nil
+}
+
+// Reload attempts to reload the certificate and key.
+func (cr *CertReloader) Reload() error {
+	newCert, err := tls.LoadX509KeyPair(cr.certPath, cr.keyPath)
+	if err != nil {
+		return err
+	}
+
+	cr.certMu.Lock()
+	defer cr.certMu.Unlock()
+	cr.cert = &newCert
+	return nil
+}
+
+// GetCertificateFunc returns a function that can be used with tls.Config.GetCertificate.
+func (cr *CertReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
+	return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
+		cr.certMu.RLock()
+		defer cr.certMu.RUnlock()
+		return cr.cert, nil
+	}
+}

cmd/soft/serve/certreloader_test.go 🔗

@@ -0,0 +1,116 @@
+//go:build unix
+
+package serve
+
+import (
+	"crypto/rand"
+	"crypto/rsa"
+	"crypto/x509"
+	"crypto/x509/pkix"
+	"encoding/pem"
+	"os"
+	"os/signal"
+	"path/filepath"
+	"syscall"
+	"testing"
+	"time"
+
+	"charm.land/log/v2"
+)
+
+func generateTestCert(t *testing.T, certPath, keyPath, cn string) {
+	t.Helper()
+
+	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	template := x509.Certificate{
+		SerialNumber: nil,
+		Subject: pkix.Name{
+			CommonName: cn,
+		},
+		NotBefore: time.Now(),
+		NotAfter:  time.Now().Add(time.Hour),
+	}
+
+	certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	certFile, err := os.Create(certPath)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer certFile.Close()
+
+	pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes})
+
+	keyFile, err := os.Create(keyPath)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer keyFile.Close()
+
+	pem.Encode(keyFile, &pem.Block{
+		Type:  "RSA PRIVATE KEY",
+		Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
+	})
+}
+
+func TestCertReloader(t *testing.T) {
+	dir := t.TempDir()
+	certPath := filepath.Join(dir, "/cert.pem")
+	keyPath := filepath.Join(dir, "/key.pem")
+
+	// Initial cert
+	generateTestCert(t, certPath, keyPath, "cert-v1")
+
+	logger := log.New(os.Stderr)
+
+	certReloader, err := NewCertReloader(certPath, keyPath, logger)
+	if err != nil {
+		t.Fatalf("failed to create reloader: %v", err)
+	}
+
+	go func() {
+		sigCh := make(chan os.Signal, 1)
+		signal.Notify(sigCh, syscall.SIGHUP)
+		for range sigCh {
+			if err := certReloader.Reload(); err != nil {
+				logger.Error("failed to reload certificate", "err", err)
+			} else {
+				logger.Info("certificate reloaded successfully")
+			}
+		}
+	}()
+
+	getCert := certReloader.GetCertificateFunc()
+
+	cert1, err := getCert(nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Replace cert on disk
+	generateTestCert(t, certPath, keyPath, "cert-v2")
+
+	// Trigger reload
+	if err := syscall.Kill(os.Getpid(), syscall.SIGHUP); err != nil {
+		t.Fatalf("failed to send SIGHUP: %v", err)
+	}
+
+	// Allow async goroutine to reload
+	time.Sleep(100 * time.Millisecond)
+
+	cert2, err := getCert(nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if cert1 == cert2 {
+		t.Fatal("certificate was not reloaded after SIGHUP")
+	}
+}

cmd/soft/serve/serve.go 🔗

@@ -86,7 +86,7 @@ var (
 			done := make(chan os.Signal, 1)
 			doneOnce := sync.OnceFunc(func() { close(done) })
 
-			signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
+			signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
 
 			// This endpoint is added for testing purposes
 			// It allows us to stop the server from the test suite.
@@ -107,12 +107,23 @@ var (
 				doneOnce()
 			}()
 
-			select {
-			case err := <-lch:
-				if err != nil {
-					return fmt.Errorf("server error: %w", err)
+			for {
+				select {
+				case err := <-lch:
+					if err != nil {
+						return fmt.Errorf("server error: %w", err)
+					}
+				case sig := <-done:
+					if sig == syscall.SIGHUP {
+						s.logger.Info("received SIGHUP signal, reloading TLS certificates if enabled")
+						if err := s.ReloadCertificates(); err != nil {
+							s.logger.Error("failed to reload TLS certificates", "err", err)
+						}
+						continue
+					}
 				}
-			case <-done:
+
+				break
 			}
 
 			ctx, cancel := context.WithTimeout(ctx, 5*time.Second)

cmd/soft/serve/server.go 🔗

@@ -2,6 +2,7 @@ package serve
 
 import (
 	"context"
+	"crypto/tls"
 	"errors"
 	"fmt"
 	"net/http"
@@ -27,6 +28,7 @@ type Server struct {
 	GitDaemon   *daemon.GitDaemon
 	HTTPServer  *web.HTTPServer
 	StatsServer *stats.StatsServer
+	CertLoader  *CertReloader
 	Cron        *cron.Scheduler
 	Config      *config.Config
 	Backend     *backend.Backend
@@ -87,9 +89,28 @@ func NewServer(ctx context.Context) (*Server, error) {
 		return nil, fmt.Errorf("create stats server: %w", err)
 	}
 
+	if cfg.HTTP.TLSKeyPath != "" && cfg.HTTP.TLSCertPath != "" {
+		srv.CertLoader, err = NewCertReloader(cfg.HTTP.TLSCertPath, cfg.HTTP.TLSKeyPath, logger)
+		if err != nil {
+			return nil, fmt.Errorf("create cert reloader: %w", err)
+		}
+
+		srv.HTTPServer.SetTLSConfig(&tls.Config{
+			GetCertificate: srv.CertLoader.GetCertificateFunc(),
+		})
+	}
+
 	return srv, nil
 }
 
+// ReloadCertificates reloads the TLS certificates for the HTTP server.
+func (s *Server) ReloadCertificates() error {
+	if s.CertLoader == nil {
+		return nil
+	}
+	return s.CertLoader.Reload()
+}
+
 // Start starts the SSH server.
 func (s *Server) Start() error {
 	errg, _ := errgroup.WithContext(s.ctx)

pkg/web/http.go 🔗

@@ -2,6 +2,7 @@ package web
 
 import (
 	"context"
+	"crypto/tls"
 	"net/http"
 	"time"
 
@@ -37,6 +38,11 @@ func NewHTTPServer(ctx context.Context) (*HTTPServer, error) {
 	return s, nil
 }
 
+// SetTLSConfig sets the TLS configuration for the HTTP server.
+func (s *HTTPServer) SetTLSConfig(tlsConfig *tls.Config) {
+	s.Server.TLSConfig = tlsConfig
+}
+
 // Close closes the HTTP server.
 func (s *HTTPServer) Close() error {
 	return s.Server.Close()
@@ -44,8 +50,8 @@ func (s *HTTPServer) Close() error {
 
 // ListenAndServe starts the HTTP server.
 func (s *HTTPServer) ListenAndServe() error {
-	if s.cfg.HTTP.TLSKeyPath != "" && s.cfg.HTTP.TLSCertPath != "" {
-		return s.Server.ListenAndServeTLS(s.cfg.HTTP.TLSCertPath, s.cfg.HTTP.TLSKeyPath)
+	if s.Server.TLSConfig != nil {
+		return s.Server.ListenAndServeTLS("", "")
 	}
 	return s.Server.ListenAndServe()
 }