Detailed changes
@@ -0,0 +1,54 @@
+package serve
+
+import (
+ "crypto/tls"
+ "sync"
+
+ "charm.land/log/v2"
+)
+
+// CertReloader is responsible for reloading TLS certificates when a SIGHUP signal is received.
+type CertReloader struct {
+ certMu sync.RWMutex
+ cert *tls.Certificate
+ certPath string
+ keyPath string
+}
+
+// NewCertReloader creates a new CertReloader that watches for SIGHUP signals.
+func NewCertReloader(certPath, keyPath string, logger *log.Logger) (*CertReloader, error) {
+ reloader := &CertReloader{
+ certPath: certPath,
+ keyPath: keyPath,
+ }
+
+ cert, err := tls.LoadX509KeyPair(certPath, keyPath)
+ if err != nil {
+ return nil, err
+ }
+ reloader.cert = &cert
+
+ return reloader, nil
+}
+
+// Reload attempts to reload the certificate and key.
+func (cr *CertReloader) Reload() error {
+ newCert, err := tls.LoadX509KeyPair(cr.certPath, cr.keyPath)
+ if err != nil {
+ return err
+ }
+
+ cr.certMu.Lock()
+ defer cr.certMu.Unlock()
+ cr.cert = &newCert
+ return nil
+}
+
+// GetCertificateFunc returns a function that can be used with tls.Config.GetCertificate.
+func (cr *CertReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
+ return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
+ cr.certMu.RLock()
+ defer cr.certMu.RUnlock()
+ return cr.cert, nil
+ }
+}
@@ -0,0 +1,116 @@
+//go:build unix
+
+package serve
+
+import (
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/pem"
+ "os"
+ "os/signal"
+ "path/filepath"
+ "syscall"
+ "testing"
+ "time"
+
+ "charm.land/log/v2"
+)
+
+func generateTestCert(t *testing.T, certPath, keyPath, cn string) {
+ t.Helper()
+
+ privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ template := x509.Certificate{
+ SerialNumber: nil,
+ Subject: pkix.Name{
+ CommonName: cn,
+ },
+ NotBefore: time.Now(),
+ NotAfter: time.Now().Add(time.Hour),
+ }
+
+ certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ certFile, err := os.Create(certPath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer certFile.Close()
+
+ pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes})
+
+ keyFile, err := os.Create(keyPath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer keyFile.Close()
+
+ pem.Encode(keyFile, &pem.Block{
+ Type: "RSA PRIVATE KEY",
+ Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
+ })
+}
+
+func TestCertReloader(t *testing.T) {
+ dir := t.TempDir()
+ certPath := filepath.Join(dir, "/cert.pem")
+ keyPath := filepath.Join(dir, "/key.pem")
+
+ // Initial cert
+ generateTestCert(t, certPath, keyPath, "cert-v1")
+
+ logger := log.New(os.Stderr)
+
+ certReloader, err := NewCertReloader(certPath, keyPath, logger)
+ if err != nil {
+ t.Fatalf("failed to create reloader: %v", err)
+ }
+
+ go func() {
+ sigCh := make(chan os.Signal, 1)
+ signal.Notify(sigCh, syscall.SIGHUP)
+ for range sigCh {
+ if err := certReloader.Reload(); err != nil {
+ logger.Error("failed to reload certificate", "err", err)
+ } else {
+ logger.Info("certificate reloaded successfully")
+ }
+ }
+ }()
+
+ getCert := certReloader.GetCertificateFunc()
+
+ cert1, err := getCert(nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Replace cert on disk
+ generateTestCert(t, certPath, keyPath, "cert-v2")
+
+ // Trigger reload
+ if err := syscall.Kill(os.Getpid(), syscall.SIGHUP); err != nil {
+ t.Fatalf("failed to send SIGHUP: %v", err)
+ }
+
+ // Allow async goroutine to reload
+ time.Sleep(100 * time.Millisecond)
+
+ cert2, err := getCert(nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if cert1 == cert2 {
+ t.Fatal("certificate was not reloaded after SIGHUP")
+ }
+}
@@ -86,7 +86,7 @@ var (
done := make(chan os.Signal, 1)
doneOnce := sync.OnceFunc(func() { close(done) })
- signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
+ signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
// This endpoint is added for testing purposes
// It allows us to stop the server from the test suite.
@@ -107,12 +107,23 @@ var (
doneOnce()
}()
- select {
- case err := <-lch:
- if err != nil {
- return fmt.Errorf("server error: %w", err)
+ for {
+ select {
+ case err := <-lch:
+ if err != nil {
+ return fmt.Errorf("server error: %w", err)
+ }
+ case sig := <-done:
+ if sig == syscall.SIGHUP {
+ s.logger.Info("received SIGHUP signal, reloading TLS certificates if enabled")
+ if err := s.ReloadCertificates(); err != nil {
+ s.logger.Error("failed to reload TLS certificates", "err", err)
+ }
+ continue
+ }
}
- case <-done:
+
+ break
}
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
@@ -2,6 +2,7 @@ package serve
import (
"context"
+ "crypto/tls"
"errors"
"fmt"
"net/http"
@@ -27,6 +28,7 @@ type Server struct {
GitDaemon *daemon.GitDaemon
HTTPServer *web.HTTPServer
StatsServer *stats.StatsServer
+ CertLoader *CertReloader
Cron *cron.Scheduler
Config *config.Config
Backend *backend.Backend
@@ -87,9 +89,28 @@ func NewServer(ctx context.Context) (*Server, error) {
return nil, fmt.Errorf("create stats server: %w", err)
}
+ if cfg.HTTP.TLSKeyPath != "" && cfg.HTTP.TLSCertPath != "" {
+ srv.CertLoader, err = NewCertReloader(cfg.HTTP.TLSCertPath, cfg.HTTP.TLSKeyPath, logger)
+ if err != nil {
+ return nil, fmt.Errorf("create cert reloader: %w", err)
+ }
+
+ srv.HTTPServer.SetTLSConfig(&tls.Config{
+ GetCertificate: srv.CertLoader.GetCertificateFunc(),
+ })
+ }
+
return srv, nil
}
+// ReloadCertificates reloads the TLS certificates for the HTTP server.
+func (s *Server) ReloadCertificates() error {
+ if s.CertLoader == nil {
+ return nil
+ }
+ return s.CertLoader.Reload()
+}
+
// Start starts the SSH server.
func (s *Server) Start() error {
errg, _ := errgroup.WithContext(s.ctx)
@@ -2,6 +2,7 @@ package web
import (
"context"
+ "crypto/tls"
"net/http"
"time"
@@ -37,6 +38,11 @@ func NewHTTPServer(ctx context.Context) (*HTTPServer, error) {
return s, nil
}
+// SetTLSConfig sets the TLS configuration for the HTTP server.
+func (s *HTTPServer) SetTLSConfig(tlsConfig *tls.Config) {
+ s.Server.TLSConfig = tlsConfig
+}
+
// Close closes the HTTP server.
func (s *HTTPServer) Close() error {
return s.Server.Close()
@@ -44,8 +50,8 @@ func (s *HTTPServer) Close() error {
// ListenAndServe starts the HTTP server.
func (s *HTTPServer) ListenAndServe() error {
- if s.cfg.HTTP.TLSKeyPath != "" && s.cfg.HTTP.TLSCertPath != "" {
- return s.Server.ListenAndServeTLS(s.cfg.HTTP.TLSCertPath, s.cfg.HTTP.TLSKeyPath)
+ if s.Server.TLSConfig != nil {
+ return s.Server.ListenAndServeTLS("", "")
}
return s.Server.ListenAndServe()
}