1package daemon
2
3import (
4 "bytes"
5 "context"
6 "errors"
7 "fmt"
8 "io"
9 "log"
10 "net"
11 "os"
12 "strings"
13 "testing"
14
15 "github.com/charmbracelet/soft-serve/server/backend"
16 "github.com/charmbracelet/soft-serve/server/config"
17 "github.com/charmbracelet/soft-serve/server/db"
18 "github.com/charmbracelet/soft-serve/server/db/migrate"
19 "github.com/charmbracelet/soft-serve/server/git"
20 "github.com/charmbracelet/soft-serve/server/store"
21 "github.com/charmbracelet/soft-serve/server/store/database"
22 "github.com/charmbracelet/soft-serve/server/test"
23 "github.com/go-git/go-git/v5/plumbing/format/pktline"
24 _ "modernc.org/sqlite" // sqlite driver
25)
26
27var testDaemon *GitDaemon
28
29func TestMain(m *testing.M) {
30 tmp, err := os.MkdirTemp("", "soft-serve-test")
31 if err != nil {
32 log.Fatal(err)
33 }
34 defer os.RemoveAll(tmp)
35 ctx := context.TODO()
36 cfg := config.DefaultConfig()
37 cfg.DataPath = tmp
38 cfg.Git.MaxConnections = 3
39 cfg.Git.MaxTimeout = 100
40 cfg.Git.IdleTimeout = 1
41 cfg.Git.ListenAddr = fmt.Sprintf(":%d", test.RandomPort())
42 if err := cfg.Validate(); err != nil {
43 log.Fatal(err)
44 }
45 ctx = config.WithContext(ctx, cfg)
46 dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
47 if err != nil {
48 log.Fatal(err)
49 }
50 defer dbx.Close() // nolint: errcheck
51 if err := migrate.Migrate(ctx, dbx); err != nil {
52 log.Fatal(err)
53 }
54 datastore := database.New(ctx, dbx)
55 ctx = store.WithContext(ctx, datastore)
56 be := backend.New(ctx, cfg, dbx)
57 ctx = backend.WithContext(ctx, be)
58 d, err := NewGitDaemon(ctx)
59 if err != nil {
60 log.Fatal(err)
61 }
62 testDaemon = d
63 go func() {
64 if err := d.Start(); err != ErrServerClosed {
65 log.Fatal(err)
66 }
67 }()
68 code := m.Run()
69 os.Unsetenv("SOFT_SERVE_DATA_PATH")
70 os.Unsetenv("SOFT_SERVE_GIT_MAX_CONNECTIONS")
71 os.Unsetenv("SOFT_SERVE_GIT_MAX_TIMEOUT")
72 os.Unsetenv("SOFT_SERVE_GIT_IDLE_TIMEOUT")
73 os.Unsetenv("SOFT_SERVE_GIT_LISTEN_ADDR")
74 _ = d.Close()
75 _ = dbx.Close()
76 os.Exit(code)
77}
78
79func TestIdleTimeout(t *testing.T) {
80 c, err := net.Dial("tcp", testDaemon.addr)
81 if err != nil {
82 t.Fatal(err)
83 }
84 out, err := readPktline(c)
85 if err != nil && !errors.Is(err, io.EOF) {
86 t.Fatalf("expected nil, got error: %v", err)
87 }
88 if out != "ERR "+git.ErrTimeout.Error() && out != "" {
89 t.Fatalf("expected %q error, got %q", git.ErrTimeout, out)
90 }
91}
92
93func TestInvalidRepo(t *testing.T) {
94 c, err := net.Dial("tcp", testDaemon.addr)
95 if err != nil {
96 t.Fatal(err)
97 }
98 if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
99 t.Fatalf("expected nil, got error: %v", err)
100 }
101 out, err := readPktline(c)
102 if err != nil {
103 t.Fatalf("expected nil, got error: %v", err)
104 }
105 if out != "ERR "+git.ErrInvalidRepo.Error() {
106 t.Fatalf("expected %q error, got %q", git.ErrInvalidRepo, out)
107 }
108}
109
110func readPktline(c net.Conn) (string, error) {
111 buf, err := io.ReadAll(c)
112 if err != nil {
113 return "", err
114 }
115 pktout := pktline.NewScanner(bytes.NewReader(buf))
116 if !pktout.Scan() {
117 return "", pktout.Err()
118 }
119 return strings.TrimSpace(string(pktout.Bytes())), nil
120}