1package daemon
2
3import (
4 "bytes"
5 "context"
6 "errors"
7 "fmt"
8 "io"
9 "log"
10 "net"
11 "os"
12 "strings"
13 "testing"
14
15 "github.com/charmbracelet/soft-serve/server/backend"
16 "github.com/charmbracelet/soft-serve/server/config"
17 "github.com/charmbracelet/soft-serve/server/db"
18 "github.com/charmbracelet/soft-serve/server/db/migrate"
19 "github.com/charmbracelet/soft-serve/server/git"
20 "github.com/charmbracelet/soft-serve/server/test"
21 "github.com/go-git/go-git/v5/plumbing/format/pktline"
22 _ "modernc.org/sqlite" // sqlite driver
23)
24
// testDaemon is the daemon instance shared by the tests in this file.
// TestMain assigns it before running the tests, which dial testDaemon.addr.
var testDaemon *GitDaemon
26
27func TestMain(m *testing.M) {
28 tmp, err := os.MkdirTemp("", "soft-serve-test")
29 if err != nil {
30 log.Fatal(err)
31 }
32 defer os.RemoveAll(tmp)
33 os.Setenv("SOFT_SERVE_DATA_PATH", tmp)
34 os.Setenv("SOFT_SERVE_GIT_MAX_CONNECTIONS", "3")
35 os.Setenv("SOFT_SERVE_GIT_MAX_TIMEOUT", "100")
36 os.Setenv("SOFT_SERVE_GIT_IDLE_TIMEOUT", "1")
37 os.Setenv("SOFT_SERVE_GIT_LISTEN_ADDR", fmt.Sprintf(":%d", test.RandomPort()))
38 ctx := context.TODO()
39 cfg := config.DefaultConfig()
40 if err := cfg.Validate(); err != nil {
41 log.Fatal(err)
42 }
43 ctx = config.WithContext(ctx, cfg)
44 db, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
45 if err != nil {
46 log.Fatal(err)
47 }
48 defer db.Close() // nolint: errcheck
49 if err := migrate.Migrate(ctx, db); err != nil {
50 log.Fatal(err)
51 }
52 be := backend.New(ctx, cfg, db)
53 ctx = backend.WithContext(ctx, be)
54 d, err := NewGitDaemon(ctx)
55 if err != nil {
56 log.Fatal(err)
57 }
58 testDaemon = d
59 go func() {
60 if err := d.Start(); err != ErrServerClosed {
61 log.Fatal(err)
62 }
63 }()
64 code := m.Run()
65 os.Unsetenv("SOFT_SERVE_DATA_PATH")
66 os.Unsetenv("SOFT_SERVE_GIT_MAX_CONNECTIONS")
67 os.Unsetenv("SOFT_SERVE_GIT_MAX_TIMEOUT")
68 os.Unsetenv("SOFT_SERVE_GIT_IDLE_TIMEOUT")
69 os.Unsetenv("SOFT_SERVE_GIT_LISTEN_ADDR")
70 _ = d.Close()
71 _ = db.Close()
72 os.Exit(code)
73}
74
75func TestIdleTimeout(t *testing.T) {
76 c, err := net.Dial("tcp", testDaemon.addr)
77 if err != nil {
78 t.Fatal(err)
79 }
80 out, err := readPktline(c)
81 if err != nil && !errors.Is(err, io.EOF) {
82 t.Fatalf("expected nil, got error: %v", err)
83 }
84 if out != "ERR "+git.ErrTimeout.Error() && out != "" {
85 t.Fatalf("expected %q error, got %q", git.ErrTimeout, out)
86 }
87}
88
89func TestInvalidRepo(t *testing.T) {
90 c, err := net.Dial("tcp", testDaemon.addr)
91 if err != nil {
92 t.Fatal(err)
93 }
94 if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
95 t.Fatalf("expected nil, got error: %v", err)
96 }
97 out, err := readPktline(c)
98 if err != nil {
99 t.Fatalf("expected nil, got error: %v", err)
100 }
101 if out != "ERR "+git.ErrInvalidRepo.Error() {
102 t.Fatalf("expected %q error, got %q", git.ErrInvalidRepo, out)
103 }
104}
105
106func readPktline(c net.Conn) (string, error) {
107 buf, err := io.ReadAll(c)
108 if err != nil {
109 return "", err
110 }
111 pktout := pktline.NewScanner(bytes.NewReader(buf))
112 if !pktout.Scan() {
113 return "", pktout.Err()
114 }
115 return strings.TrimSpace(string(pktout.Bytes())), nil
116}