1package daemon
2
3import (
4 "bytes"
5 "io"
6 "log"
7 "net"
8 "os"
9 "testing"
10
11 "github.com/charmbracelet/soft-serve/server/config"
12 "github.com/charmbracelet/soft-serve/server/git"
13 "github.com/go-git/go-git/v5/plumbing/format/pktline"
14)
15
16var testDaemon *Daemon
17
18func TestMain(m *testing.M) {
19 tmp, err := os.MkdirTemp("", "soft-serve-test")
20 if err != nil {
21 log.Fatal(err)
22 }
23 defer os.RemoveAll(tmp)
24 cfg := &config.Config{
25 Host: "",
26 DataPath: tmp,
27 Git: config.GitConfig{
28 // Reduce the max timeout to 100 second so we can test the timeout.
29 MaxTimeout: 100,
30 // Reduce the max read timeout to 1 second so we can test the timeout.
31 MaxReadTimeout: 1,
32 // Reduce the max connections to 3 so we can test the timeout.
33 MaxConnections: 3,
34 Port: 9418,
35 },
36 }
37 d, err := NewDaemon(cfg)
38 if err != nil {
39 log.Fatal(err)
40 }
41 testDaemon = d
42 go func() {
43 if err := d.Start(); err != ErrServerClosed {
44 log.Fatal(err)
45 }
46 }()
47 defer d.Close()
48 os.Exit(m.Run())
49}
50
51func TestMaxReadTimeout(t *testing.T) {
52 c, err := net.Dial("tcp", testDaemon.addr)
53 if err != nil {
54 t.Fatal(err)
55 }
56 out, err := readPktline(c)
57 if err != nil {
58 t.Fatalf("expected nil, got error: %v", err)
59 }
60 if out != git.ErrMaxTimeout.Error() {
61 t.Fatalf("expected %q error, got nil", git.ErrMaxTimeout)
62 }
63}
64
65func TestInvalidRepo(t *testing.T) {
66 c, err := net.Dial("tcp", testDaemon.addr)
67 if err != nil {
68 t.Fatal(err)
69 }
70 if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
71 t.Fatalf("expected nil, got error: %v", err)
72 }
73 out, err := readPktline(c)
74 if err != nil {
75 t.Fatalf("expected nil, got error: %v", err)
76 }
77 if out != git.ErrInvalidRepo.Error() {
78 t.Fatalf("expected %q error, got nil", git.ErrInvalidRepo)
79 }
80}
81
82func readPktline(c net.Conn) (string, error) {
83 buf, err := io.ReadAll(c)
84 if err != nil {
85 return "", err
86 }
87 pktout := pktline.NewScanner(bytes.NewReader(buf))
88 if !pktout.Scan() {
89 return "", pktout.Err()
90 }
91 return string(pktout.Bytes()), nil
92}