1package daemon
2
3import (
4 "bytes"
5 "context"
6 "errors"
7 "fmt"
8 "io"
9 "log"
10 "net"
11 "os"
12 "strings"
13 "testing"
14
15 "github.com/charmbracelet/soft-serve/server/backend"
16 "github.com/charmbracelet/soft-serve/server/backend/sqlite"
17 "github.com/charmbracelet/soft-serve/server/config"
18 "github.com/charmbracelet/soft-serve/server/git"
19 "github.com/charmbracelet/soft-serve/server/test"
20 "github.com/go-git/go-git/v5/plumbing/format/pktline"
21)
22
23var testDaemon *GitDaemon
24
25func TestMain(m *testing.M) {
26 tmp, err := os.MkdirTemp("", "soft-serve-test")
27 if err != nil {
28 log.Fatal(err)
29 }
30 defer os.RemoveAll(tmp)
31 os.Setenv("SOFT_SERVE_DATA_PATH", tmp)
32 os.Setenv("SOFT_SERVE_GIT_MAX_CONNECTIONS", "3")
33 os.Setenv("SOFT_SERVE_GIT_MAX_TIMEOUT", "100")
34 os.Setenv("SOFT_SERVE_GIT_IDLE_TIMEOUT", "1")
35 os.Setenv("SOFT_SERVE_GIT_LISTEN_ADDR", fmt.Sprintf(":%d", test.RandomPort()))
36 ctx := context.TODO()
37 cfg := config.DefaultConfig()
38 if err := cfg.WriteConfig(); err != nil {
39 log.Fatal("failed to write default config: %w", err)
40 }
41 ctx = config.WithContext(ctx, cfg)
42 fb, err := sqlite.NewSqliteBackend(ctx)
43 if err != nil {
44 log.Fatal(err)
45 }
46 cfg = cfg.WithBackend(fb)
47 ctx = backend.WithContext(ctx, fb)
48 d, err := NewGitDaemon(ctx)
49 if err != nil {
50 log.Fatal(err)
51 }
52 testDaemon = d
53 go func() {
54 if err := d.Start(); err != ErrServerClosed {
55 log.Fatal(err)
56 }
57 }()
58 code := m.Run()
59 os.Unsetenv("SOFT_SERVE_DATA_PATH")
60 os.Unsetenv("SOFT_SERVE_GIT_MAX_CONNECTIONS")
61 os.Unsetenv("SOFT_SERVE_GIT_MAX_TIMEOUT")
62 os.Unsetenv("SOFT_SERVE_GIT_IDLE_TIMEOUT")
63 os.Unsetenv("SOFT_SERVE_GIT_LISTEN_ADDR")
64 _ = d.Close()
65 _ = fb.Close()
66 os.Exit(code)
67}
68
69func TestIdleTimeout(t *testing.T) {
70 c, err := net.Dial("tcp", testDaemon.addr)
71 if err != nil {
72 t.Fatal(err)
73 }
74 out, err := readPktline(c)
75 if err != nil && !errors.Is(err, io.EOF) {
76 t.Fatalf("expected nil, got error: %v", err)
77 }
78 if out != git.ErrTimeout.Error() && out != "" {
79 t.Fatalf("expected %q error, got %q", git.ErrTimeout, out)
80 }
81}
82
83func TestInvalidRepo(t *testing.T) {
84 c, err := net.Dial("tcp", testDaemon.addr)
85 if err != nil {
86 t.Fatal(err)
87 }
88 if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
89 t.Fatalf("expected nil, got error: %v", err)
90 }
91 out, err := readPktline(c)
92 if err != nil {
93 t.Fatalf("expected nil, got error: %v", err)
94 }
95 if out != git.ErrInvalidRepo.Error() {
96 t.Fatalf("expected %q error, got %q", git.ErrInvalidRepo, out)
97 }
98}
99
100func readPktline(c net.Conn) (string, error) {
101 buf, err := io.ReadAll(c)
102 if err != nil {
103 return "", err
104 }
105 pktout := pktline.NewScanner(bytes.NewReader(buf))
106 if !pktout.Scan() {
107 return "", pktout.Err()
108 }
109 return strings.TrimSpace(string(pktout.Bytes())), nil
110}