1package daemon
2
3import (
4 "bytes"
5 "io"
6 "log"
7 "net"
8 "os"
9 "testing"
10
11 "github.com/charmbracelet/soft-serve/proto"
12 "github.com/charmbracelet/soft-serve/server/config"
13 "github.com/charmbracelet/soft-serve/server/db/fakedb"
14 "github.com/charmbracelet/soft-serve/server/git"
15 "github.com/go-git/go-git/v5/plumbing/format/pktline"
16)
17
// testDaemon is the Daemon created and started by TestMain. The tests in this
// file dial its addr field to talk to the running server.
var testDaemon *Daemon
19
20func TestMain(m *testing.M) {
21 tmp, err := os.MkdirTemp("", "soft-serve-test")
22 if err != nil {
23 log.Fatal(err)
24 }
25 defer os.RemoveAll(tmp)
26 cfg := &config.Config{
27 Host: "",
28 DataPath: tmp,
29 AnonAccess: proto.ReadOnlyAccess,
30 Git: config.GitConfig{
31 // Reduce the max read timeout to 1 second so we can test the timeout.
32 IdleTimeout: 3,
33 // Reduce the max timeout to 100 second so we can test the timeout.
34 MaxTimeout: 100,
35 // Reduce the max connections to 3 so we can test the timeout.
36 MaxConnections: 3,
37 Port: 9418,
38 },
39 }
40 cfg = cfg.WithDB(&fakedb.FakeDB{})
41 d, err := NewDaemon(cfg)
42 if err != nil {
43 log.Fatal(err)
44 }
45 testDaemon = d
46 go func() {
47 if err := d.Start(); err != ErrServerClosed {
48 log.Fatal(err)
49 }
50 }()
51 defer d.Close()
52 os.Exit(m.Run())
53}
54
55func TestIdleTimeout(t *testing.T) {
56 c, err := net.Dial("tcp", testDaemon.addr)
57 if err != nil {
58 t.Fatal(err)
59 }
60 out, err := readPktline(c)
61 if err != nil {
62 t.Fatalf("expected nil, got error: %v", err)
63 }
64 if out != git.ErrTimeout.Error() {
65 t.Fatalf("expected %q error, got %q", git.ErrTimeout, out)
66 }
67}
68
69func TestInvalidRepo(t *testing.T) {
70 c, err := net.Dial("tcp", testDaemon.addr)
71 if err != nil {
72 t.Fatal(err)
73 }
74 if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
75 t.Fatalf("expected nil, got error: %v", err)
76 }
77 out, err := readPktline(c)
78 if err != nil {
79 t.Fatalf("expected nil, got error: %v", err)
80 }
81 if out != git.ErrInvalidRepo.Error() {
82 t.Fatalf("expected %q error, got %q", git.ErrInvalidRepo, out)
83 }
84}
85
86func readPktline(c net.Conn) (string, error) {
87 buf, err := io.ReadAll(c)
88 if err != nil {
89 return "", err
90 }
91 pktout := pktline.NewScanner(bytes.NewReader(buf))
92 if !pktout.Scan() {
93 return "", pktout.Err()
94 }
95 return string(pktout.Bytes()), nil
96}