1package daemon
2
3import (
4 "context"
5 "fmt"
6 "log"
7 "net"
8 "os"
9 "testing"
10 "time"
11
12 "github.com/charmbracelet/soft-serve/pkg/backend"
13 "github.com/charmbracelet/soft-serve/pkg/config"
14 "github.com/charmbracelet/soft-serve/pkg/db"
15 "github.com/charmbracelet/soft-serve/pkg/db/migrate"
16 "github.com/charmbracelet/soft-serve/pkg/git"
17 "github.com/charmbracelet/soft-serve/pkg/store"
18 "github.com/charmbracelet/soft-serve/pkg/store/database"
19 "github.com/charmbracelet/soft-serve/pkg/test"
20 "github.com/go-git/go-git/v5/plumbing/format/pktline"
21 _ "modernc.org/sqlite" // sqlite driver
22)
23
24var testDaemon *GitDaemon
25
26func TestMain(m *testing.M) {
27 tmp, err := os.MkdirTemp("", "soft-serve-test")
28 if err != nil {
29 log.Fatal(err)
30 }
31 defer os.RemoveAll(tmp)
32 ctx := context.TODO()
33 cfg := config.DefaultConfig()
34 cfg.DataPath = tmp
35 cfg.Git.MaxConnections = 3
36 cfg.Git.MaxTimeout = 100
37 cfg.Git.IdleTimeout = 1
38 cfg.Git.ListenAddr = fmt.Sprintf(":%d", test.RandomPort())
39 if err := cfg.Validate(); err != nil {
40 log.Fatal(err)
41 }
42 ctx = config.WithContext(ctx, cfg)
43 dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
44 if err != nil {
45 log.Fatal(err)
46 }
47 defer dbx.Close() //nolint: errcheck
48 if err := migrate.Migrate(ctx, dbx); err != nil {
49 log.Fatal(err)
50 }
51 datastore := database.New(ctx, dbx)
52 ctx = store.WithContext(ctx, datastore)
53 be := backend.New(ctx, cfg, dbx, datastore)
54 ctx = backend.WithContext(ctx, be)
55 d, err := NewGitDaemon(ctx)
56 if err != nil {
57 log.Fatal(err)
58 }
59 testDaemon = d
60 go d.ListenAndServe() //nolint:errcheck
61 code := m.Run()
62 os.Unsetenv("SOFT_SERVE_DATA_PATH")
63 os.Unsetenv("SOFT_SERVE_GIT_MAX_CONNECTIONS")
64 os.Unsetenv("SOFT_SERVE_GIT_MAX_TIMEOUT")
65 os.Unsetenv("SOFT_SERVE_GIT_IDLE_TIMEOUT")
66 os.Unsetenv("SOFT_SERVE_GIT_LISTEN_ADDR")
67 _ = d.Close()
68 _ = dbx.Close()
69 os.Exit(code)
70}
71
72func TestIdleTimeout(t *testing.T) {
73 var err error
74 var c net.Conn
75 var tries int
76 for {
77 c, err = net.Dial("tcp", testDaemon.addr)
78 if err != nil && tries >= 3 {
79 t.Fatalf("failed to connect to daemon after %d tries: %v", tries, err)
80 }
81 tries++
82 if testDaemon.conns.Size() != 0 {
83 break
84 }
85 time.Sleep(10 * time.Millisecond)
86 }
87 time.Sleep(2 * time.Second)
88 err = readPktline(c)
89 if err == nil {
90 t.Errorf("expected error, got nil")
91 }
92}
93
94func TestInvalidRepo(t *testing.T) {
95 c, err := net.Dial("tcp", testDaemon.addr)
96 if err != nil {
97 t.Fatalf("failed to connect to daemon: %v", err)
98 }
99 if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
100 t.Fatalf("expected nil, got error: %v", err)
101 }
102 err = readPktline(c)
103 if err != nil && err.Error() != git.ErrInvalidRepo.Error() {
104 t.Errorf("expected %q error, got %q", git.ErrInvalidRepo, err)
105 }
106}
107
108func readPktline(c net.Conn) error {
109 pktout := pktline.NewScanner(c)
110 if !pktout.Scan() {
111 return pktout.Err()
112 }
113 return nil
114}