1package daemon
2
import (
	"bytes"
	"errors"
	"io"
	"log"
	"net"
	"os"
	"testing"

	appCfg "github.com/charmbracelet/soft-serve/config"
	"github.com/charmbracelet/soft-serve/server/config"
	"github.com/charmbracelet/soft-serve/server/git"
	"github.com/go-git/go-git/v5/plumbing/format/pktline"
)
16
17var testDaemon *Daemon
18
19func TestMain(m *testing.M) {
20 tmp, err := os.MkdirTemp("", "soft-serve-test")
21 if err != nil {
22 log.Fatal(err)
23 }
24 defer os.RemoveAll(tmp)
25 cfg := &config.Config{
26 Host: "",
27 DataPath: tmp,
28 Git: config.GitConfig{
29 // Reduce the max timeout to 100 second so we can test the timeout.
30 MaxTimeout: 100,
31 // Reduce the max read timeout to 1 second so we can test the timeout.
32 MaxReadTimeout: 1,
33 // Reduce the max connections to 3 so we can test the timeout.
34 MaxConnections: 3,
35 Port: 9418,
36 },
37 }
38 ac, err := appCfg.NewConfig(cfg)
39 if err != nil {
40 log.Fatal(err)
41 }
42 d, err := NewDaemon(cfg, ac)
43 if err != nil {
44 log.Fatal(err)
45 }
46 testDaemon = d
47 go func() {
48 if err := d.Start(); err != ErrServerClosed {
49 log.Fatal(err)
50 }
51 }()
52 defer d.Close()
53 os.Exit(m.Run())
54}
55
56func TestMaxReadTimeout(t *testing.T) {
57 c, err := net.Dial("tcp", testDaemon.addr)
58 if err != nil {
59 t.Fatal(err)
60 }
61 out, err := readPktline(c)
62 if err != nil {
63 t.Fatalf("expected nil, got error: %v", err)
64 }
65 if out != git.ErrMaxTimeout.Error() {
66 t.Fatalf("expected %q error, got nil", git.ErrMaxTimeout)
67 }
68}
69
70func TestInvalidRepo(t *testing.T) {
71 c, err := net.Dial("tcp", testDaemon.addr)
72 if err != nil {
73 t.Fatal(err)
74 }
75 if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
76 t.Fatalf("expected nil, got error: %v", err)
77 }
78 out, err := readPktline(c)
79 if err != nil {
80 t.Fatalf("expected nil, got error: %v", err)
81 }
82 if out != git.ErrInvalidRepo.Error() {
83 t.Fatalf("expected %q error, got nil", git.ErrInvalidRepo)
84 }
85}
86
87func readPktline(c net.Conn) (string, error) {
88 buf, err := io.ReadAll(c)
89 if err != nil {
90 return "", err
91 }
92 pktout := pktline.NewScanner(bytes.NewReader(buf))
93 if !pktout.Scan() {
94 return "", pktout.Err()
95 }
96 return string(pktout.Bytes()), nil
97}