1package daemon
2
3import (
4 "bytes"
5 "context"
6 "io"
7 "log"
8 "net"
9 "os"
10 "testing"
11
12 appCfg "github.com/charmbracelet/soft-serve/config"
13 "github.com/charmbracelet/soft-serve/server/config"
14 "github.com/charmbracelet/soft-serve/server/git"
15 "github.com/go-git/go-git/v5/plumbing/format/pktline"
16)
17
18var testDaemon *Daemon
19
20func TestMain(m *testing.M) {
21 cfg := config.DefaultConfig()
22 // Reduce the max connections to 3 so we can test the timeout.
23 cfg.GitMaxConnections = 3
24 // Reduce the max timeout to 100 second so we can test the timeout.
25 cfg.GitMaxTimeout = 100
26 // Reduce the max read timeout to 1 second so we can test the timeout.
27 cfg.GitMaxReadTimeout = 1
28 ac, err := appCfg.NewConfig(cfg)
29 if err != nil {
30 log.Fatal(err)
31 }
32 d, err := NewDaemon(cfg, ac)
33 if err != nil {
34 log.Fatal(err)
35 }
36 testDaemon = d
37 go func() {
38 if err := d.Start(); err != ErrServerClosed {
39 log.Fatal(err)
40 }
41 }()
42 defer d.Shutdown(context.Background())
43 os.Exit(m.Run())
44}
45
46func TestMaxReadTimeout(t *testing.T) {
47 c, err := net.Dial("tcp", testDaemon.addr)
48 if err != nil {
49 t.Fatal(err)
50 }
51 out, err := readPktline(c)
52 if err != nil {
53 t.Fatalf("expected nil, got error: %v", err)
54 }
55 if out != git.ErrMaxTimeout.Error() {
56 t.Fatalf("expected %q error, got nil", git.ErrMaxTimeout)
57 }
58}
59
60func TestInvalidRepo(t *testing.T) {
61 c, err := net.Dial("tcp", testDaemon.addr)
62 if err != nil {
63 t.Fatal(err)
64 }
65 if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
66 t.Fatalf("expected nil, got error: %v", err)
67 }
68 out, err := readPktline(c)
69 if err != nil {
70 t.Fatalf("expected nil, got error: %v", err)
71 }
72 if out != git.ErrInvalidRepo.Error() {
73 t.Fatalf("expected %q error, got nil", git.ErrInvalidRepo)
74 }
75}
76
77func readPktline(c net.Conn) (string, error) {
78 buf, err := io.ReadAll(c)
79 if err != nil {
80 return "", err
81 }
82 pktout := pktline.NewScanner(bytes.NewReader(buf))
83 if !pktout.Scan() {
84 return "", pktout.Err()
85 }
86 return string(pktout.Bytes()), nil
87}