1package daemon
2
3import (
4 "bytes"
5 "context"
6 "errors"
7 "fmt"
8 "io"
9 "log"
10 "net"
11 "os"
12 "strings"
13 "testing"
14
15 "github.com/charmbracelet/soft-serve/server/backend"
16 "github.com/charmbracelet/soft-serve/server/backend/sqlite"
17 "github.com/charmbracelet/soft-serve/server/cache"
18 _ "github.com/charmbracelet/soft-serve/server/cache/noop" // noop driver
19 "github.com/charmbracelet/soft-serve/server/git"
20 "github.com/charmbracelet/soft-serve/server/test"
21 "github.com/go-git/go-git/v5/plumbing/format/pktline"
22)
23
// testDaemon is the shared daemon created and started by TestMain; the tests
// in this file connect to it by dialing its addr.
var testDaemon *GitDaemon
25
26func TestMain(m *testing.M) {
27 tmp, err := os.MkdirTemp("", "soft-serve-test")
28 if err != nil {
29 log.Fatal(err)
30 }
31 defer os.RemoveAll(tmp)
32 os.Setenv("SOFT_SERVE_DATA_PATH", tmp)
33 os.Setenv("SOFT_SERVE_GIT_MAX_CONNECTIONS", "3")
34 os.Setenv("SOFT_SERVE_GIT_MAX_TIMEOUT", "100")
35 os.Setenv("SOFT_SERVE_GIT_IDLE_TIMEOUT", "1")
36 os.Setenv("SOFT_SERVE_GIT_LISTEN_ADDR", fmt.Sprintf(":%d", test.RandomPort()))
37 ctx := context.TODO()
38 ca, _ := cache.New(ctx, "noop")
39 ctx = cache.WithContext(ctx, ca)
40 fb, err := sqlite.NewSqliteBackend(ctx)
41 if err != nil {
42 log.Fatal(err)
43 }
44 ctx = backend.WithContext(ctx, fb)
45 d, err := NewGitDaemon(ctx)
46 if err != nil {
47 log.Fatal(err)
48 }
49 testDaemon = d
50 go func() {
51 if err := d.Start(); err != ErrServerClosed {
52 log.Fatal(err)
53 }
54 }()
55 code := m.Run()
56 os.Unsetenv("SOFT_SERVE_DATA_PATH")
57 os.Unsetenv("SOFT_SERVE_GIT_MAX_CONNECTIONS")
58 os.Unsetenv("SOFT_SERVE_GIT_MAX_TIMEOUT")
59 os.Unsetenv("SOFT_SERVE_GIT_IDLE_TIMEOUT")
60 os.Unsetenv("SOFT_SERVE_GIT_LISTEN_ADDR")
61 _ = d.Close()
62 _ = fb.Close()
63 os.Exit(code)
64}
65
66func TestIdleTimeout(t *testing.T) {
67 c, err := net.Dial("tcp", testDaemon.addr)
68 if err != nil {
69 t.Fatal(err)
70 }
71 out, err := readPktline(c)
72 if err != nil && !errors.Is(err, io.EOF) {
73 t.Fatalf("expected nil, got error: %v", err)
74 }
75 if out != git.ErrTimeout.Error() && out != "" {
76 t.Fatalf("expected %q error, got %q", git.ErrTimeout, out)
77 }
78}
79
80func TestInvalidRepo(t *testing.T) {
81 c, err := net.Dial("tcp", testDaemon.addr)
82 if err != nil {
83 t.Fatal(err)
84 }
85 if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
86 t.Fatalf("expected nil, got error: %v", err)
87 }
88 out, err := readPktline(c)
89 if err != nil {
90 t.Fatalf("expected nil, got error: %v", err)
91 }
92 if out != git.ErrNotExist.Error() {
93 t.Fatalf("expected %q error, got %q", git.ErrNotExist, out)
94 }
95}
96
97func readPktline(c net.Conn) (string, error) {
98 buf, err := io.ReadAll(c)
99 if err != nil {
100 return "", err
101 }
102 pktout := pktline.NewScanner(bytes.NewReader(buf))
103 if !pktout.Scan() {
104 return "", pktout.Err()
105 }
106 return strings.TrimSpace(string(pktout.Bytes())), nil
107}