From 28c48548903dcbdf5d73438bce920af34bd66fae Mon Sep 17 00:00:00 2001 From: Harsh Mantri <24585585+cheesyhypocrisy@users.noreply.github.com> Date: Fri, 9 Jan 2026 08:40:52 +0530 Subject: [PATCH] feat: add support for certificate reloading upon SIGHUP (#710) * feat: add support for certificate reloading upon SIGHUP * fix: support certificate reloading for unix and add test * fix(cmd): move cert reloader logic to the serve package --------- Co-authored-by: Ayman Bagabas --- cmd/soft/serve/certreloader.go | 54 +++++++++++++ cmd/soft/serve/certreloader_test.go | 116 ++++++++++++++++++++++++++++ cmd/soft/serve/serve.go | 23 ++++-- cmd/soft/serve/server.go | 21 +++++ pkg/web/http.go | 10 ++- 5 files changed, 216 insertions(+), 8 deletions(-) create mode 100644 cmd/soft/serve/certreloader.go create mode 100644 cmd/soft/serve/certreloader_test.go diff --git a/cmd/soft/serve/certreloader.go b/cmd/soft/serve/certreloader.go new file mode 100644 index 0000000000000000000000000000000000000000..34dc4d9b3bc5ab6e0ea6558ebf89a2cebb8c6b0f --- /dev/null +++ b/cmd/soft/serve/certreloader.go @@ -0,0 +1,54 @@ +package serve + +import ( + "crypto/tls" + "sync" + + "charm.land/log/v2" +) + +// CertReloader is responsible for reloading TLS certificates when a SIGHUP signal is received. +type CertReloader struct { + certMu sync.RWMutex + cert *tls.Certificate + certPath string + keyPath string +} + +// NewCertReloader creates a new CertReloader that watches for SIGHUP signals. +func NewCertReloader(certPath, keyPath string, logger *log.Logger) (*CertReloader, error) { + reloader := &CertReloader{ + certPath: certPath, + keyPath: keyPath, + } + + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return nil, err + } + reloader.cert = &cert + + return reloader, nil +} + +// Reload attempts to reload the certificate and key. +func (cr *CertReloader) Reload() error { + newCert, err := tls.LoadX509KeyPair(cr.certPath, cr.keyPath) + if err != nil { + return err + } + + cr.certMu.Lock() + defer cr.certMu.Unlock() + cr.cert = &newCert + return nil +} + +// GetCertificateFunc returns a function that can be used with tls.Config.GetCertificate. +func (cr *CertReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + cr.certMu.RLock() + defer cr.certMu.RUnlock() + return cr.cert, nil + } +} diff --git a/cmd/soft/serve/certreloader_test.go b/cmd/soft/serve/certreloader_test.go new file mode 100644 index 0000000000000000000000000000000000000000..e22fcf790c1e80fd8c3d9827ddd57f74e1b49c06 --- /dev/null +++ b/cmd/soft/serve/certreloader_test.go @@ -0,0 +1,116 @@ +//go:build unix + +package serve + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "os" + "os/signal" + "path/filepath" + "syscall" + "testing" + "time" + + "charm.land/log/v2" +) + +func generateTestCert(t *testing.T, certPath, keyPath, cn string) { + t.Helper() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + + template := x509.Certificate{ + SerialNumber: nil, + Subject: pkix.Name{ + CommonName: cn, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + } + + certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + if err != nil { + t.Fatal(err) + } + + certFile, err := os.Create(certPath) + if err != nil { + t.Fatal(err) + } + defer certFile.Close() + + pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes}) + + keyFile, err := os.Create(keyPath) + if err != nil { + t.Fatal(err) + } + defer keyFile.Close() + + pem.Encode(keyFile, &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + }) +} + +func TestCertReloader(t *testing.T) { + dir := t.TempDir() + certPath := filepath.Join(dir, "/cert.pem") + keyPath := filepath.Join(dir, "/key.pem") + + // Initial cert + generateTestCert(t, certPath, keyPath, "cert-v1") + + logger := log.New(os.Stderr) + + certReloader, err := NewCertReloader(certPath, keyPath, logger) + if err != nil { + t.Fatalf("failed to create reloader: %v", err) + } + + go func() { + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGHUP) + for range sigCh { + if err := certReloader.Reload(); err != nil { + logger.Error("failed to reload certificate", "err", err) + } else { + logger.Info("certificate reloaded successfully") + } + } + }() + + getCert := certReloader.GetCertificateFunc() + + cert1, err := getCert(nil) + if err != nil { + t.Fatal(err) + } + + // Replace cert on disk + generateTestCert(t, certPath, keyPath, "cert-v2") + + // Trigger reload + if err := syscall.Kill(os.Getpid(), syscall.SIGHUP); err != nil { + t.Fatalf("failed to send SIGHUP: %v", err) + } + + // Allow async goroutine to reload + time.Sleep(100 * time.Millisecond) + + cert2, err := getCert(nil) + if err != nil { + t.Fatal(err) + } + + if cert1 == cert2 { + t.Fatal("certificate was not reloaded after SIGHUP") + } +} diff --git a/cmd/soft/serve/serve.go b/cmd/soft/serve/serve.go index 21a9f080ae16065fb5edcc32b46c0b8f2e4de441..7472f3d0be14300f5dcb827197308abfcc6d9ff6 100644 --- a/cmd/soft/serve/serve.go +++ b/cmd/soft/serve/serve.go @@ -86,7 +86,7 @@ var ( done := make(chan os.Signal, 1) doneOnce := sync.OnceFunc(func() { close(done) }) - signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) // This endpoint is added for testing purposes // It allows us to stop the server from the test suite. @@ -107,12 +107,23 @@ var ( doneOnce() }() - select { - case err := <-lch: - if err != nil { - return fmt.Errorf("server error: %w", err) + for { + select { + case err := <-lch: + if err != nil { + return fmt.Errorf("server error: %w", err) + } + case sig := <-done: + if sig == syscall.SIGHUP { + s.logger.Info("received SIGHUP signal, reloading TLS certificates if enabled") + if err := s.ReloadCertificates(); err != nil { + s.logger.Error("failed to reload TLS certificates", "err", err) + } + continue + } } - case <-done: + + break } ctx, cancel := context.WithTimeout(ctx, 5*time.Second) diff --git a/cmd/soft/serve/server.go b/cmd/soft/serve/server.go index 4260f7576da3fccc12f5b296255f70490ed72682..fda09005097dce89960b6cd6af26431ceca54c32 100644 --- a/cmd/soft/serve/server.go +++ b/cmd/soft/serve/server.go @@ -2,6 +2,7 @@ package serve import ( "context" + "crypto/tls" "errors" "fmt" "net/http" @@ -27,6 +28,7 @@ type Server struct { GitDaemon *daemon.GitDaemon HTTPServer *web.HTTPServer StatsServer *stats.StatsServer + CertLoader *CertReloader Cron *cron.Scheduler Config *config.Config Backend *backend.Backend @@ -87,9 +89,28 @@ func NewServer(ctx context.Context) (*Server, error) { return nil, fmt.Errorf("create stats server: %w", err) } + if cfg.HTTP.TLSKeyPath != "" && cfg.HTTP.TLSCertPath != "" { + srv.CertLoader, err = NewCertReloader(cfg.HTTP.TLSCertPath, cfg.HTTP.TLSKeyPath, logger) + if err != nil { + return nil, fmt.Errorf("create cert reloader: %w", err) + } + + srv.HTTPServer.SetTLSConfig(&tls.Config{ + GetCertificate: srv.CertLoader.GetCertificateFunc(), + }) + } + return srv, nil } +// ReloadCertificates reloads the TLS certificates for the HTTP server. +func (s *Server) ReloadCertificates() error { + if s.CertLoader == nil { + return nil + } + return s.CertLoader.Reload() +} + // Start starts the SSH server. func (s *Server) Start() error { errg, _ := errgroup.WithContext(s.ctx) diff --git a/pkg/web/http.go b/pkg/web/http.go index 7bb255f4007d471fb83bf604357d3d60a8e192c7..531d02bfd89c4fbbf71caf64d02eaebe44585bc5 100644 --- a/pkg/web/http.go +++ b/pkg/web/http.go @@ -2,6 +2,7 @@ package web import ( "context" + "crypto/tls" "net/http" "time" @@ -37,6 +38,11 @@ func NewHTTPServer(ctx context.Context) (*HTTPServer, error) { return s, nil } +// SetTLSConfig sets the TLS configuration for the HTTP server. +func (s *HTTPServer) SetTLSConfig(tlsConfig *tls.Config) { + s.Server.TLSConfig = tlsConfig +} + // Close closes the HTTP server. func (s *HTTPServer) Close() error { return s.Server.Close() @@ -44,8 +50,8 @@ func (s *HTTPServer) Close() error { // ListenAndServe starts the HTTP server. func (s *HTTPServer) ListenAndServe() error { - if s.cfg.HTTP.TLSKeyPath != "" && s.cfg.HTTP.TLSCertPath != "" { - return s.Server.ListenAndServeTLS(s.cfg.HTTP.TLSCertPath, s.cfg.HTTP.TLSKeyPath) + if s.Server.TLSConfig != nil { + return s.Server.ListenAndServeTLS("", "") } return s.Server.ListenAndServe() }