certreloader_test.go

  1//go:build unix
  2
  3package serve
  4
  5import (
  6	"crypto/rand"
  7	"crypto/rsa"
  8	"crypto/x509"
  9	"crypto/x509/pkix"
 10	"encoding/pem"
 11	"os"
 12	"os/signal"
 13	"path/filepath"
 14	"syscall"
 15	"testing"
 16	"time"
 17
 18	"charm.land/log/v2"
 19)
 20
 21func generateTestCert(t *testing.T, certPath, keyPath, cn string) {
 22	t.Helper()
 23
 24	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
 25	if err != nil {
 26		t.Fatal(err)
 27	}
 28
 29	template := x509.Certificate{
 30		SerialNumber: nil,
 31		Subject: pkix.Name{
 32			CommonName: cn,
 33		},
 34		NotBefore: time.Now(),
 35		NotAfter:  time.Now().Add(time.Hour),
 36	}
 37
 38	certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
 39	if err != nil {
 40		t.Fatal(err)
 41	}
 42
 43	certFile, err := os.Create(certPath)
 44	if err != nil {
 45		t.Fatal(err)
 46	}
 47	defer certFile.Close()
 48
 49	pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes})
 50
 51	keyFile, err := os.Create(keyPath)
 52	if err != nil {
 53		t.Fatal(err)
 54	}
 55	defer keyFile.Close()
 56
 57	pem.Encode(keyFile, &pem.Block{
 58		Type:  "RSA PRIVATE KEY",
 59		Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
 60	})
 61}
 62
 63func TestCertReloader(t *testing.T) {
 64	dir := t.TempDir()
 65	certPath := filepath.Join(dir, "/cert.pem")
 66	keyPath := filepath.Join(dir, "/key.pem")
 67
 68	// Initial cert
 69	generateTestCert(t, certPath, keyPath, "cert-v1")
 70
 71	logger := log.New(os.Stderr)
 72
 73	certReloader, err := NewCertReloader(certPath, keyPath, logger)
 74	if err != nil {
 75		t.Fatalf("failed to create reloader: %v", err)
 76	}
 77
 78	go func() {
 79		sigCh := make(chan os.Signal, 1)
 80		signal.Notify(sigCh, syscall.SIGHUP)
 81		for range sigCh {
 82			if err := certReloader.Reload(); err != nil {
 83				logger.Error("failed to reload certificate", "err", err)
 84			} else {
 85				logger.Info("certificate reloaded successfully")
 86			}
 87		}
 88	}()
 89
 90	getCert := certReloader.GetCertificateFunc()
 91
 92	cert1, err := getCert(nil)
 93	if err != nil {
 94		t.Fatal(err)
 95	}
 96
 97	// Replace cert on disk
 98	generateTestCert(t, certPath, keyPath, "cert-v2")
 99
100	// Trigger reload
101	if err := syscall.Kill(os.Getpid(), syscall.SIGHUP); err != nil {
102		t.Fatalf("failed to send SIGHUP: %v", err)
103	}
104
105	// Allow async goroutine to reload
106	time.Sleep(100 * time.Millisecond)
107
108	cert2, err := getCert(nil)
109	if err != nil {
110		t.Fatal(err)
111	}
112
113	if cert1 == cert2 {
114		t.Fatal("certificate was not reloaded after SIGHUP")
115	}
116}