1//go:build unix
2
3package serve
4
import (
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"math/big"
	"os"
	"os/signal"
	"path/filepath"
	"syscall"
	"testing"
	"time"

	"charm.land/log/v2"
)
20
// generateTestCert writes a freshly generated, self-signed certificate with
// the given common name to certPath and its RSA private key to keyPath, both
// PEM-encoded. Existing files at either path are overwritten. The certificate
// is valid for one hour from now. Any failure aborts the calling test.
func generateTestCert(t *testing.T, certPath, keyPath, cn string) {
	t.Helper()

	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatal(err)
	}

	// x509.CreateCertificate rejects a nil SerialNumber on Go versions before
	// 1.24, so always supply one. A random 128-bit serial also keeps
	// successively generated certificates distinguishable.
	serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
	if err != nil {
		t.Fatal(err)
	}

	now := time.Now()
	template := x509.Certificate{
		SerialNumber: serial,
		Subject: pkix.Name{
			CommonName: cn,
		},
		NotBefore: now,
		NotAfter:  now.Add(time.Hour),
	}

	certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
	if err != nil {
		t.Fatal(err)
	}

	// Encode in memory and write in one call so that encoding, write and
	// close failures are all reported instead of silently dropped.
	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
	if err := os.WriteFile(certPath, certPEM, 0o644); err != nil {
		t.Fatal(err)
	}

	keyPEM := pem.EncodeToMemory(&pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
	})
	// Private key material is created readable by the owner only.
	if err := os.WriteFile(keyPath, keyPEM, 0o600); err != nil {
		t.Fatal(err)
	}
}
62
63func TestCertReloader(t *testing.T) {
64 dir := t.TempDir()
65 certPath := filepath.Join(dir, "/cert.pem")
66 keyPath := filepath.Join(dir, "/key.pem")
67
68 // Initial cert
69 generateTestCert(t, certPath, keyPath, "cert-v1")
70
71 logger := log.New(os.Stderr)
72
73 certReloader, err := NewCertReloader(certPath, keyPath, logger)
74 if err != nil {
75 t.Fatalf("failed to create reloader: %v", err)
76 }
77
78 go func() {
79 sigCh := make(chan os.Signal, 1)
80 signal.Notify(sigCh, syscall.SIGHUP)
81 for range sigCh {
82 if err := certReloader.Reload(); err != nil {
83 logger.Error("failed to reload certificate", "err", err)
84 } else {
85 logger.Info("certificate reloaded successfully")
86 }
87 }
88 }()
89
90 getCert := certReloader.GetCertificateFunc()
91
92 cert1, err := getCert(nil)
93 if err != nil {
94 t.Fatal(err)
95 }
96
97 // Replace cert on disk
98 generateTestCert(t, certPath, keyPath, "cert-v2")
99
100 // Trigger reload
101 if err := syscall.Kill(os.Getpid(), syscall.SIGHUP); err != nil {
102 t.Fatalf("failed to send SIGHUP: %v", err)
103 }
104
105 // Allow async goroutine to reload
106 time.Sleep(100 * time.Millisecond)
107
108 cert2, err := getCert(nil)
109 if err != nil {
110 t.Fatal(err)
111 }
112
113 if cert1 == cert2 {
114 t.Fatal("certificate was not reloaded after SIGHUP")
115 }
116}