1package daemon
2
3import (
4 "context"
5 "fmt"
6 "log"
7 "net"
8 "os"
9 "strings"
10 "testing"
11 "time"
12
13 "github.com/charmbracelet/soft-serve/pkg/backend"
14 "github.com/charmbracelet/soft-serve/pkg/config"
15 "github.com/charmbracelet/soft-serve/pkg/db"
16 "github.com/charmbracelet/soft-serve/pkg/db/migrate"
17 "github.com/charmbracelet/soft-serve/pkg/git"
18 "github.com/charmbracelet/soft-serve/pkg/store"
19 "github.com/charmbracelet/soft-serve/pkg/store/database"
20 "github.com/charmbracelet/soft-serve/pkg/test"
21 "github.com/go-git/go-git/v5/plumbing/format/pktline"
22 _ "modernc.org/sqlite" // sqlite driver
23)
24
25var testDaemon *GitDaemon
26
27func TestMain(m *testing.M) {
28 tmp, err := os.MkdirTemp("", "soft-serve-test")
29 if err != nil {
30 log.Fatal(err)
31 }
32 defer os.RemoveAll(tmp)
33 ctx := context.TODO()
34 cfg := config.DefaultConfig()
35 cfg.DataPath = tmp
36 cfg.Git.MaxConnections = 3
37 cfg.Git.MaxTimeout = 100
38 cfg.Git.IdleTimeout = 1
39 cfg.Git.ListenAddr = fmt.Sprintf(":%d", test.RandomPort())
40 if err := cfg.Validate(); err != nil {
41 log.Fatal(err)
42 }
43 ctx = config.WithContext(ctx, cfg)
44 dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
45 if err != nil {
46 log.Fatal(err)
47 }
48 defer dbx.Close() //nolint: errcheck
49 if err := migrate.Migrate(ctx, dbx); err != nil {
50 log.Fatal(err)
51 }
52 datastore := database.New(ctx, dbx)
53 ctx = store.WithContext(ctx, datastore)
54 be := backend.New(ctx, cfg, dbx, datastore)
55 ctx = backend.WithContext(ctx, be)
56 d, err := NewGitDaemon(ctx)
57 if err != nil {
58 log.Fatal(err)
59 }
60 testDaemon = d
61 go d.ListenAndServe() //nolint:errcheck
62 code := m.Run()
63 os.Unsetenv("SOFT_SERVE_DATA_PATH")
64 os.Unsetenv("SOFT_SERVE_GIT_MAX_CONNECTIONS")
65 os.Unsetenv("SOFT_SERVE_GIT_MAX_TIMEOUT")
66 os.Unsetenv("SOFT_SERVE_GIT_IDLE_TIMEOUT")
67 os.Unsetenv("SOFT_SERVE_GIT_LISTEN_ADDR")
68 _ = d.Close()
69 _ = dbx.Close()
70 os.Exit(code)
71}
72
73func TestIdleTimeout(t *testing.T) {
74 var err error
75 var c net.Conn
76 var tries int
77 for {
78 c, err = net.Dial("tcp", testDaemon.addr)
79 if err != nil && tries >= 3 {
80 t.Fatalf("failed to connect to daemon after %d tries: %v", tries, err)
81 }
82 tries++
83 if testDaemon.conns.Size() != 0 {
84 break
85 }
86 time.Sleep(10 * time.Millisecond)
87 }
88 time.Sleep(2 * time.Second)
89 _, err = readPktline(c)
90 if err == nil {
91 t.Errorf("expected error, got nil")
92 }
93}
94
95func TestInvalidRepo(t *testing.T) {
96 c, err := net.Dial("tcp", testDaemon.addr)
97 if err != nil {
98 t.Fatalf("failed to connect to daemon: %v", err)
99 }
100 if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
101 t.Fatalf("expected nil, got error: %v", err)
102 }
103 _, err = readPktline(c)
104 if err != nil && err.Error() != git.ErrInvalidRepo.Error() {
105 t.Errorf("expected %q error, got %q", git.ErrInvalidRepo, err)
106 }
107}
108
109func readPktline(c net.Conn) (string, error) {
110 pktout := pktline.NewScanner(c)
111 if !pktout.Scan() {
112 return "", pktout.Err()
113 }
114 return strings.TrimSpace(string(pktout.Bytes())), nil
115}