1package server
2
3import (
4 "bytes"
5 "errors"
6 "fmt"
7 "io"
8 "log"
9 "net"
10 "os"
11 "strings"
12 "testing"
13
14 "github.com/charmbracelet/soft-serve/server/backend/sqlite"
15 "github.com/charmbracelet/soft-serve/server/config"
16 "github.com/go-git/go-git/v5/plumbing/format/pktline"
17)
18
19var testDaemon *GitDaemon
20
21func TestMain(m *testing.M) {
22 tmp, err := os.MkdirTemp("", "soft-serve-test")
23 if err != nil {
24 log.Fatal(err)
25 }
26 defer os.RemoveAll(tmp)
27 os.Setenv("SOFT_SERVE_DATA_PATH", tmp)
28 os.Setenv("SOFT_SERVE_GIT_MAX_CONNECTIONS", "3")
29 os.Setenv("SOFT_SERVE_GIT_MAX_TIMEOUT", "100")
30 os.Setenv("SOFT_SERVE_GIT_IDLE_TIMEOUT", "1")
31 os.Setenv("SOFT_SERVE_GIT_LISTEN_ADDR", fmt.Sprintf(":%d", randomPort()))
32 cfg := config.DefaultConfig()
33 d, err := NewGitDaemon(cfg)
34 if err != nil {
35 log.Fatal(err)
36 }
37 fb, err := sqlite.NewSqliteBackend(cfg)
38 if err != nil {
39 log.Fatal(err)
40 }
41 cfg = cfg.WithBackend(fb)
42 testDaemon = d
43 go func() {
44 if err := d.Start(); err != ErrServerClosed {
45 log.Fatal(err)
46 }
47 }()
48 code := m.Run()
49 os.Unsetenv("SOFT_SERVE_DATA_PATH")
50 os.Unsetenv("SOFT_SERVE_GIT_MAX_CONNECTIONS")
51 os.Unsetenv("SOFT_SERVE_GIT_MAX_TIMEOUT")
52 os.Unsetenv("SOFT_SERVE_GIT_IDLE_TIMEOUT")
53 os.Unsetenv("SOFT_SERVE_GIT_LISTEN_ADDR")
54 _ = d.Close()
55 _ = fb.Close()
56 os.Exit(code)
57}
58
59func TestIdleTimeout(t *testing.T) {
60 c, err := net.Dial("tcp", testDaemon.addr)
61 if err != nil {
62 t.Fatal(err)
63 }
64 out, err := readPktline(c)
65 if err != nil && !errors.Is(err, io.EOF) {
66 t.Fatalf("expected nil, got error: %v", err)
67 }
68 if out != ErrTimeout.Error() || out == "" {
69 t.Fatalf("expected %q error, got %q", ErrTimeout, out)
70 }
71}
72
73func TestInvalidRepo(t *testing.T) {
74 c, err := net.Dial("tcp", testDaemon.addr)
75 if err != nil {
76 t.Fatal(err)
77 }
78 if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
79 t.Fatalf("expected nil, got error: %v", err)
80 }
81 out, err := readPktline(c)
82 if err != nil {
83 t.Fatalf("expected nil, got error: %v", err)
84 }
85 if out != ErrInvalidRepo.Error() {
86 t.Fatalf("expected %q error, got %q", ErrInvalidRepo, out)
87 }
88}
89
90func readPktline(c net.Conn) (string, error) {
91 buf, err := io.ReadAll(c)
92 if err != nil {
93 return "", err
94 }
95 pktout := pktline.NewScanner(bytes.NewReader(buf))
96 if !pktout.Scan() {
97 return "", pktout.Err()
98 }
99 return strings.TrimSpace(string(pktout.Bytes())), nil
100}