1package daemon
2
3import (
4 "context"
5 "fmt"
6 "log"
7 "net"
8 "os"
9 "strings"
10 "testing"
11 "time"
12
13 "github.com/charmbracelet/soft-serve/pkg/backend"
14 "github.com/charmbracelet/soft-serve/pkg/config"
15 "github.com/charmbracelet/soft-serve/pkg/db"
16 "github.com/charmbracelet/soft-serve/pkg/db/migrate"
17 "github.com/charmbracelet/soft-serve/pkg/git"
18 "github.com/charmbracelet/soft-serve/pkg/store"
19 "github.com/charmbracelet/soft-serve/pkg/store/database"
20 "github.com/charmbracelet/soft-serve/pkg/test"
21 "github.com/go-git/go-git/v5/plumbing/format/pktline"
22 _ "modernc.org/sqlite" // sqlite driver
23)
24
25var testDaemon *GitDaemon
26
27func TestMain(m *testing.M) {
28 tmp, err := os.MkdirTemp("", "soft-serve-test")
29 if err != nil {
30 log.Fatal(err)
31 }
32 defer os.RemoveAll(tmp)
33 ctx := context.TODO()
34 cfg := config.DefaultConfig()
35 cfg.DataPath = tmp
36 cfg.Git.MaxConnections = 3
37 cfg.Git.MaxTimeout = 100
38 cfg.Git.IdleTimeout = 1
39 cfg.Git.ListenAddr = fmt.Sprintf(":%d", test.RandomPort())
40 if err := cfg.Validate(); err != nil {
41 log.Fatal(err)
42 }
43 ctx = config.WithContext(ctx, cfg)
44 dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
45 if err != nil {
46 log.Fatal(err)
47 }
48 defer dbx.Close() // nolint: errcheck
49 if err := migrate.Migrate(ctx, dbx); err != nil {
50 log.Fatal(err)
51 }
52 datastore := database.New(ctx, dbx)
53 ctx = store.WithContext(ctx, datastore)
54 be := backend.New(ctx, cfg, dbx, datastore)
55 ctx = backend.WithContext(ctx, be)
56 d, err := NewGitDaemon(ctx)
57 if err != nil {
58 log.Fatal(err)
59 }
60 testDaemon = d
61 go func() {
62 if err := d.ListenAndServe(); err != ErrServerClosed {
63 log.Fatal(err)
64 }
65 }()
66 code := m.Run()
67 os.Unsetenv("SOFT_SERVE_DATA_PATH")
68 os.Unsetenv("SOFT_SERVE_GIT_MAX_CONNECTIONS")
69 os.Unsetenv("SOFT_SERVE_GIT_MAX_TIMEOUT")
70 os.Unsetenv("SOFT_SERVE_GIT_IDLE_TIMEOUT")
71 os.Unsetenv("SOFT_SERVE_GIT_LISTEN_ADDR")
72 _ = d.Close()
73 _ = dbx.Close()
74 os.Exit(code)
75}
76
77func TestIdleTimeout(t *testing.T) {
78 var err error
79 var c net.Conn
80 var tries int
81 for {
82 c, err = net.Dial("tcp", testDaemon.addr)
83 if err != nil && tries >= 3 {
84 t.Fatal(err)
85 }
86 tries++
87 if testDaemon.conns.Size() != 0 {
88 break
89 }
90 time.Sleep(10 * time.Millisecond)
91 }
92 time.Sleep(2 * time.Second)
93 _, err = readPktline(c)
94 if err == nil {
95 t.Errorf("expected error, got nil")
96 }
97}
98
99func TestInvalidRepo(t *testing.T) {
100 c, err := net.Dial("tcp", testDaemon.addr)
101 if err != nil {
102 t.Fatal(err)
103 }
104 if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
105 t.Fatalf("expected nil, got error: %v", err)
106 }
107 _, err = readPktline(c)
108 if err != nil && err.Error() != git.ErrInvalidRepo.Error() {
109 t.Errorf("expected %q error, got %q", git.ErrInvalidRepo, err)
110 }
111}
112
113func readPktline(c net.Conn) (string, error) {
114 pktout := pktline.NewScanner(c)
115 if !pktout.Scan() {
116 return "", pktout.Err()
117 }
118 return strings.TrimSpace(string(pktout.Bytes())), nil
119}