1package daemon
2
3import (
4 "context"
5 "fmt"
6 "log"
7 "net"
8 "os"
9 "strings"
10 "testing"
11
12 "github.com/charmbracelet/soft-serve/pkg/backend"
13 "github.com/charmbracelet/soft-serve/pkg/config"
14 "github.com/charmbracelet/soft-serve/pkg/db"
15 "github.com/charmbracelet/soft-serve/pkg/db/migrate"
16 "github.com/charmbracelet/soft-serve/pkg/git"
17 "github.com/charmbracelet/soft-serve/pkg/store"
18 "github.com/charmbracelet/soft-serve/pkg/store/database"
19 "github.com/charmbracelet/soft-serve/pkg/test"
20 "github.com/go-git/go-git/v5/plumbing/format/pktline"
21 _ "modernc.org/sqlite" // sqlite driver
22)
23
// testDaemon is the Git daemon shared by the tests in this file. TestMain
// assigns it before running the tests, which dial testDaemon.addr.
var testDaemon *GitDaemon
25
26func TestMain(m *testing.M) {
27 tmp, err := os.MkdirTemp("", "soft-serve-test")
28 if err != nil {
29 log.Fatal(err)
30 }
31 defer os.RemoveAll(tmp)
32 ctx := context.TODO()
33 cfg := config.DefaultConfig()
34 cfg.DataPath = tmp
35 cfg.Git.MaxConnections = 3
36 cfg.Git.MaxTimeout = 100
37 cfg.Git.IdleTimeout = 1
38 cfg.Git.ListenAddr = fmt.Sprintf(":%d", test.RandomPort())
39 if err := cfg.Validate(); err != nil {
40 log.Fatal(err)
41 }
42 ctx = config.WithContext(ctx, cfg)
43 dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
44 if err != nil {
45 log.Fatal(err)
46 }
47 defer dbx.Close() // nolint: errcheck
48 if err := migrate.Migrate(ctx, dbx); err != nil {
49 log.Fatal(err)
50 }
51 datastore := database.New(ctx, dbx)
52 ctx = store.WithContext(ctx, datastore)
53 be := backend.New(ctx, cfg, dbx)
54 ctx = backend.WithContext(ctx, be)
55 d, err := NewGitDaemon(ctx)
56 if err != nil {
57 log.Fatal(err)
58 }
59 testDaemon = d
60 go func() {
61 if err := d.Start(); err != ErrServerClosed {
62 log.Fatal(err)
63 }
64 }()
65 code := m.Run()
66 os.Unsetenv("SOFT_SERVE_DATA_PATH")
67 os.Unsetenv("SOFT_SERVE_GIT_MAX_CONNECTIONS")
68 os.Unsetenv("SOFT_SERVE_GIT_MAX_TIMEOUT")
69 os.Unsetenv("SOFT_SERVE_GIT_IDLE_TIMEOUT")
70 os.Unsetenv("SOFT_SERVE_GIT_LISTEN_ADDR")
71 _ = d.Close()
72 _ = dbx.Close()
73 os.Exit(code)
74}
75
76func TestIdleTimeout(t *testing.T) {
77 c, err := net.Dial("tcp", testDaemon.addr)
78 if err != nil {
79 t.Fatal(err)
80 }
81 _, err = readPktline(c)
82 if err != nil && err.Error() != git.ErrTimeout.Error() {
83 t.Fatalf("expected %q error, got %q", git.ErrTimeout, err)
84 }
85}
86
87func TestInvalidRepo(t *testing.T) {
88 c, err := net.Dial("tcp", testDaemon.addr)
89 if err != nil {
90 t.Fatal(err)
91 }
92 if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
93 t.Fatalf("expected nil, got error: %v", err)
94 }
95 _, err = readPktline(c)
96 if err != nil && err.Error() != git.ErrInvalidRepo.Error() {
97 t.Fatalf("expected %q error, got %q", git.ErrInvalidRepo, err)
98 }
99}
100
101func readPktline(c net.Conn) (string, error) {
102 pktout := pktline.NewScanner(c)
103 if !pktout.Scan() {
104 return "", pktout.Err()
105 }
106 return strings.TrimSpace(string(pktout.Bytes())), nil
107}