1package daemon
2
3import (
4 "bytes"
5 "context"
6 "errors"
7 "fmt"
8 "io"
9 "log"
10 "net"
11 "os"
12 "strings"
13 "testing"
14
15 "github.com/charmbracelet/soft-serve/server/backend/sqlite"
16 "github.com/charmbracelet/soft-serve/server/config"
17 "github.com/charmbracelet/soft-serve/server/git"
18 "github.com/charmbracelet/soft-serve/server/test"
19 "github.com/go-git/go-git/v5/plumbing/format/pktline"
20)
21
22var testDaemon *GitDaemon
23
24func TestMain(m *testing.M) {
25 tmp, err := os.MkdirTemp("", "soft-serve-test")
26 if err != nil {
27 log.Fatal(err)
28 }
29 defer os.RemoveAll(tmp)
30 os.Setenv("SOFT_SERVE_DATA_PATH", tmp)
31 os.Setenv("SOFT_SERVE_GIT_MAX_CONNECTIONS", "3")
32 os.Setenv("SOFT_SERVE_GIT_MAX_TIMEOUT", "100")
33 os.Setenv("SOFT_SERVE_GIT_IDLE_TIMEOUT", "1")
34 os.Setenv("SOFT_SERVE_GIT_LISTEN_ADDR", fmt.Sprintf(":%d", test.RandomPort()))
35 ctx := context.TODO()
36 cfg := config.DefaultConfig()
37 ctx = config.WithContext(ctx, cfg)
38 d, err := NewGitDaemon(ctx)
39 if err != nil {
40 log.Fatal(err)
41 }
42 fb, err := sqlite.NewSqliteBackend(ctx)
43 if err != nil {
44 log.Fatal(err)
45 }
46 cfg = cfg.WithBackend(fb)
47 testDaemon = d
48 go func() {
49 if err := d.Start(); err != ErrServerClosed {
50 log.Fatal(err)
51 }
52 }()
53 code := m.Run()
54 os.Unsetenv("SOFT_SERVE_DATA_PATH")
55 os.Unsetenv("SOFT_SERVE_GIT_MAX_CONNECTIONS")
56 os.Unsetenv("SOFT_SERVE_GIT_MAX_TIMEOUT")
57 os.Unsetenv("SOFT_SERVE_GIT_IDLE_TIMEOUT")
58 os.Unsetenv("SOFT_SERVE_GIT_LISTEN_ADDR")
59 _ = d.Close()
60 _ = fb.Close()
61 os.Exit(code)
62}
63
64func TestIdleTimeout(t *testing.T) {
65 c, err := net.Dial("tcp", testDaemon.addr)
66 if err != nil {
67 t.Fatal(err)
68 }
69 out, err := readPktline(c)
70 if err != nil && !errors.Is(err, io.EOF) {
71 t.Fatalf("expected nil, got error: %v", err)
72 }
73 if out != git.ErrTimeout.Error() && out != "" {
74 t.Fatalf("expected %q error, got %q", git.ErrTimeout, out)
75 }
76}
77
78func TestInvalidRepo(t *testing.T) {
79 c, err := net.Dial("tcp", testDaemon.addr)
80 if err != nil {
81 t.Fatal(err)
82 }
83 if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
84 t.Fatalf("expected nil, got error: %v", err)
85 }
86 out, err := readPktline(c)
87 if err != nil {
88 t.Fatalf("expected nil, got error: %v", err)
89 }
90 if out != git.ErrInvalidRepo.Error() {
91 t.Fatalf("expected %q error, got %q", git.ErrInvalidRepo, out)
92 }
93}
94
95func readPktline(c net.Conn) (string, error) {
96 buf, err := io.ReadAll(c)
97 if err != nil {
98 return "", err
99 }
100 pktout := pktline.NewScanner(bytes.NewReader(buf))
101 if !pktout.Scan() {
102 return "", pktout.Err()
103 }
104 return strings.TrimSpace(string(pktout.Bytes())), nil
105}