1package daemon
2
3import (
4 "bytes"
5 "context"
6 "fmt"
7 "io"
8 "log"
9 "net"
10 "os"
11 "strings"
12 "testing"
13
14 "github.com/charmbracelet/soft-serve/pkg/backend"
15 "github.com/charmbracelet/soft-serve/pkg/config"
16 "github.com/charmbracelet/soft-serve/pkg/db"
17 "github.com/charmbracelet/soft-serve/pkg/db/migrate"
18 "github.com/charmbracelet/soft-serve/pkg/git"
19 "github.com/charmbracelet/soft-serve/pkg/store"
20 "github.com/charmbracelet/soft-serve/pkg/store/database"
21 "github.com/charmbracelet/soft-serve/pkg/test"
22 "github.com/go-git/go-git/v5/plumbing/format/pktline"
23 _ "modernc.org/sqlite" // sqlite driver
24)
25
// testDaemon is the Git daemon shared by every test in this package. It is
// assigned during TestMain's setup, before any test runs; tests connect to it
// by dialing its addr field.
var testDaemon *GitDaemon
27
28func TestMain(m *testing.M) {
29 tmp, err := os.MkdirTemp("", "soft-serve-test")
30 if err != nil {
31 log.Fatal(err)
32 }
33 defer os.RemoveAll(tmp)
34 ctx := context.TODO()
35 cfg := config.DefaultConfig()
36 cfg.DataPath = tmp
37 cfg.Git.MaxConnections = 3
38 cfg.Git.MaxTimeout = 100
39 cfg.Git.IdleTimeout = 1
40 cfg.Git.ListenAddr = fmt.Sprintf(":%d", test.RandomPort())
41 if err := cfg.Validate(); err != nil {
42 log.Fatal(err)
43 }
44 ctx = config.WithContext(ctx, cfg)
45 dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
46 if err != nil {
47 log.Fatal(err)
48 }
49 defer dbx.Close() // nolint: errcheck
50 if err := migrate.Migrate(ctx, dbx); err != nil {
51 log.Fatal(err)
52 }
53 datastore := database.New(ctx, dbx)
54 ctx = store.WithContext(ctx, datastore)
55 be := backend.New(ctx, cfg, dbx)
56 ctx = backend.WithContext(ctx, be)
57 d, err := NewGitDaemon(ctx)
58 if err != nil {
59 log.Fatal(err)
60 }
61 testDaemon = d
62 go func() {
63 if err := d.Start(); err != ErrServerClosed {
64 log.Fatal(err)
65 }
66 }()
67 code := m.Run()
68 os.Unsetenv("SOFT_SERVE_DATA_PATH")
69 os.Unsetenv("SOFT_SERVE_GIT_MAX_CONNECTIONS")
70 os.Unsetenv("SOFT_SERVE_GIT_MAX_TIMEOUT")
71 os.Unsetenv("SOFT_SERVE_GIT_IDLE_TIMEOUT")
72 os.Unsetenv("SOFT_SERVE_GIT_LISTEN_ADDR")
73 _ = d.Close()
74 _ = dbx.Close()
75 os.Exit(code)
76}
77
78func TestIdleTimeout(t *testing.T) {
79 c, err := net.Dial("tcp", testDaemon.addr)
80 if err != nil {
81 t.Fatal(err)
82 }
83 _, err = readPktline(c)
84 if err != nil && err.Error() != git.ErrTimeout.Error() {
85 t.Fatalf("expected %q error, got %q", git.ErrTimeout, err)
86 }
87}
88
89func TestInvalidRepo(t *testing.T) {
90 c, err := net.Dial("tcp", testDaemon.addr)
91 if err != nil {
92 t.Fatal(err)
93 }
94 if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
95 t.Fatalf("expected nil, got error: %v", err)
96 }
97 _, err = readPktline(c)
98 if err != nil && err.Error() != git.ErrInvalidRepo.Error() {
99 t.Fatalf("expected %q error, got %q", git.ErrInvalidRepo, err)
100 }
101}
102
103func readPktline(c net.Conn) (string, error) {
104 buf, err := io.ReadAll(c)
105 if err != nil {
106 return "", err
107 }
108 pktout := pktline.NewScanner(bytes.NewReader(buf))
109 if !pktout.Scan() {
110 return "", pktout.Err()
111 }
112 return strings.TrimSpace(string(pktout.Bytes())), nil
113}