1package daemon
2
3import (
4 "bytes"
5 "io"
6 "log"
7 "net"
8 "os"
9 "strconv"
10 "testing"
11
12 "github.com/charmbracelet/soft-serve/server/config"
13 "github.com/charmbracelet/soft-serve/server/git"
14 "github.com/go-git/go-git/v5/plumbing/format/pktline"
15)
16
17var testDaemon *Daemon
18
// TestMain sets up a single git daemon shared by every test in the package.
//
// It points the daemon at a throwaway data directory and configures it
// through SOFT_SERVE_* environment variables. It then starts the daemon in
// the background and runs the tests. Finally it shuts the daemon down,
// clears the environment and removes the data directory.
//
// os.Exit (and log.Fatal, which calls it) does not run deferred functions.
// For that reason all teardown happens explicitly before exiting, not via
// defer.
func TestMain(m *testing.M) {
	tmp, err := os.MkdirTemp("", "soft-serve-test")
	if err != nil {
		log.Fatal(err)
	}

	env := []struct{ key, value string }{
		{"SOFT_SERVE_DATA_PATH", tmp},
		{"SOFT_SERVE_ANON_ACCESS", "read-only"},
		{"SOFT_SERVE_GIT_MAX_CONNECTIONS", "3"},
		{"SOFT_SERVE_GIT_MAX_TIMEOUT", "100"},
		{"SOFT_SERVE_GIT_IDLE_TIMEOUT", "3"},
		{"SOFT_SERVE_GIT_PORT", strconv.Itoa(randomPort())},
	}

	// cleanup restores the environment and removes the data directory.
	// It is best-effort: errors are ignored because there is nothing
	// useful left to do with them at teardown time.
	cleanup := func() {
		for _, e := range env {
			_ = os.Unsetenv(e.key)
		}
		_ = os.RemoveAll(tmp)
	}

	for _, e := range env {
		if err := os.Setenv(e.key, e.value); err != nil {
			cleanup()
			log.Fatal(err)
		}
	}

	// DefaultConfig is expected to pick up the SOFT_SERVE_* variables set
	// above.
	cfg := config.DefaultConfig()
	d, err := NewDaemon(cfg)
	if err != nil {
		cleanup()
		log.Fatal(err)
	}
	testDaemon = d

	go func() {
		// ErrServerClosed is the expected result once the daemon is closed
		// during teardown. Any other non-nil error means it failed to serve.
		if err := d.Start(); err != nil && err != ErrServerClosed {
			log.Fatal(err)
		}
	}()
	// NOTE(review): tests dial testDaemon.addr immediately. If the listener
	// is only opened inside Start, there is a startup race here. Confirm
	// where NewDaemon/Start bind the socket.

	code := m.Run()

	if err := d.Close(); err != nil {
		log.Printf("closing daemon: %v", err)
	}
	cleanup()
	os.Exit(code)
}
51
// TestIdleTimeout connects to the shared daemon without sending a request.
// It expects the daemon to answer with a pkt-line carrying git.ErrTimeout
// and then close the connection. readPktline reads until EOF, so the close
// is also required.
//
// The idle timeout is presumably the one configured through
// SOFT_SERVE_GIT_IDLE_TIMEOUT in TestMain.
func TestIdleTimeout(t *testing.T) {
	c, err := net.Dial("tcp", testDaemon.addr)
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()

	// Deliberately send nothing and wait for the daemon to give up.
	out, err := readPktline(c)
	if err != nil {
		t.Fatalf("expected nil, got error: %v", err)
	}
	if out != git.ErrTimeout.Error() {
		t.Fatalf("expected %q error, got %q", git.ErrTimeout, out)
	}
}
65
// TestInvalidRepo requests git-upload-pack for "/test.git". That repository
// is never created in the fresh data directory set up by TestMain. The test
// expects the daemon to reply with a pkt-line carrying git.ErrInvalidRepo.
func TestInvalidRepo(t *testing.T) {
	c, err := net.Dial("tcp", testDaemon.addr)
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()

	// Request line of the git daemon protocol: "<service> <path>\x00",
	// framed as a single pkt-line.
	if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
		t.Fatalf("expected nil, got error: %v", err)
	}
	out, err := readPktline(c)
	if err != nil {
		t.Fatalf("expected nil, got error: %v", err)
	}
	if out != git.ErrInvalidRepo.Error() {
		t.Fatalf("expected %q error, got %q", git.ErrInvalidRepo, out)
	}
}
82
// readPktline reads r until EOF and returns the payload of the first
// pkt-line found in the data. The tests use it to read the single error
// line the daemon sends before closing a connection. Because it waits for
// EOF, it also implicitly checks that the peer closed its side.
//
// Any read error from r or from the pkt-line scanner is returned as-is. If
// the stream ends cleanly without containing any pkt-line, it returns
// io.ErrUnexpectedEOF rather than an empty string with a nil error.
func readPktline(r io.Reader) (string, error) {
	buf, err := io.ReadAll(r)
	if err != nil {
		return "", err
	}
	scanner := pktline.NewScanner(bytes.NewReader(buf))
	if !scanner.Scan() {
		if err := scanner.Err(); err != nil {
			return "", err
		}
		// A clean EOF with no packets: report it instead of returning "".
		return "", io.ErrUnexpectedEOF
	}
	return string(scanner.Bytes()), nil
}
94
// randomPort returns a TCP port number that was free a moment ago.
//
// It briefly listens on port 0, which makes the OS choose an available
// port, records that port and closes the listener. The port is therefore
// only likely, not guaranteed, to still be free when the caller binds it.
//
// If no listener can be opened, it terminates the process via log.Fatalf.
func randomPort() int {
	l, err := net.Listen("tcp", ":0") //nolint:gosec
	if err != nil {
		log.Fatalf("randomPort: listen: %v", err)
	}
	port := l.Addr().(*net.TCPAddr).Port
	if err := l.Close(); err != nil {
		// The port number is still valid; just record the close failure.
		log.Printf("randomPort: closing listener: %v", err)
	}
	return port
}