1package daemon
2
3import (
4 "bytes"
5 "context"
6 "errors"
7 "fmt"
8 "io"
9 "log"
10 "net"
11 "os"
12 "strings"
13 "testing"
14
15 "github.com/charmbracelet/soft-serve/server/backend"
16 "github.com/charmbracelet/soft-serve/server/backend/sqlite"
17 "github.com/charmbracelet/soft-serve/server/cache"
18 "github.com/charmbracelet/soft-serve/server/cache/noop"
19 "github.com/charmbracelet/soft-serve/server/config"
20 "github.com/charmbracelet/soft-serve/server/git"
21 "github.com/charmbracelet/soft-serve/server/test"
22 "github.com/go-git/go-git/v5/plumbing/format/pktline"
23)
24
// testDaemon is the GitDaemon created and started by TestMain; the tests in
// this file dial its listen address (testDaemon.addr).
var testDaemon *GitDaemon
26
27func TestMain(m *testing.M) {
28 tmp, err := os.MkdirTemp("", "soft-serve-test")
29 if err != nil {
30 log.Fatal(err)
31 }
32 defer os.RemoveAll(tmp)
33 os.Setenv("SOFT_SERVE_DATA_PATH", tmp)
34 os.Setenv("SOFT_SERVE_GIT_MAX_CONNECTIONS", "3")
35 os.Setenv("SOFT_SERVE_GIT_MAX_TIMEOUT", "100")
36 os.Setenv("SOFT_SERVE_GIT_IDLE_TIMEOUT", "1")
37 os.Setenv("SOFT_SERVE_GIT_LISTEN_ADDR", fmt.Sprintf(":%d", test.RandomPort()))
38 ctx := context.TODO()
39 ca, _ := noop.NewCache(ctx)
40 ctx = cache.WithContext(ctx, ca)
41 cfg := config.DefaultConfig()
42 if err := cfg.WriteConfig(); err != nil {
43 log.Fatal("failed to write default config: %w", err)
44 }
45 ctx = config.WithContext(ctx, cfg)
46 fb, err := sqlite.NewSqliteBackend(ctx)
47 if err != nil {
48 log.Fatal(err)
49 }
50 cfg = cfg.WithBackend(fb)
51 ctx = backend.WithContext(ctx, fb)
52 d, err := NewGitDaemon(ctx)
53 if err != nil {
54 log.Fatal(err)
55 }
56 testDaemon = d
57 go func() {
58 if err := d.Start(); err != ErrServerClosed {
59 log.Fatal(err)
60 }
61 }()
62 code := m.Run()
63 os.Unsetenv("SOFT_SERVE_DATA_PATH")
64 os.Unsetenv("SOFT_SERVE_GIT_MAX_CONNECTIONS")
65 os.Unsetenv("SOFT_SERVE_GIT_MAX_TIMEOUT")
66 os.Unsetenv("SOFT_SERVE_GIT_IDLE_TIMEOUT")
67 os.Unsetenv("SOFT_SERVE_GIT_LISTEN_ADDR")
68 _ = d.Close()
69 _ = fb.Close()
70 os.Exit(code)
71}
72
73func TestIdleTimeout(t *testing.T) {
74 c, err := net.Dial("tcp", testDaemon.addr)
75 if err != nil {
76 t.Fatal(err)
77 }
78 out, err := readPktline(c)
79 if err != nil && !errors.Is(err, io.EOF) {
80 t.Fatalf("expected nil, got error: %v", err)
81 }
82 if out != git.ErrTimeout.Error() && out != "" {
83 t.Fatalf("expected %q error, got %q", git.ErrTimeout, out)
84 }
85}
86
87func TestInvalidRepo(t *testing.T) {
88 c, err := net.Dial("tcp", testDaemon.addr)
89 if err != nil {
90 t.Fatal(err)
91 }
92 if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
93 t.Fatalf("expected nil, got error: %v", err)
94 }
95 out, err := readPktline(c)
96 if err != nil {
97 t.Fatalf("expected nil, got error: %v", err)
98 }
99 if out != git.ErrInvalidRepo.Error() {
100 t.Fatalf("expected %q error, got %q", git.ErrInvalidRepo, out)
101 }
102}
103
104func readPktline(c net.Conn) (string, error) {
105 buf, err := io.ReadAll(c)
106 if err != nil {
107 return "", err
108 }
109 pktout := pktline.NewScanner(bytes.NewReader(buf))
110 if !pktout.Scan() {
111 return "", pktout.Err()
112 }
113 return strings.TrimSpace(string(pktout.Bytes())), nil
114}