1package server
2
3import (
4 "bytes"
5 "context"
6 "errors"
7 "fmt"
8 "io"
9 "log"
10 "net"
11 "os"
12 "strings"
13 "testing"
14
15 "github.com/charmbracelet/soft-serve/server/backend/sqlite"
16 "github.com/charmbracelet/soft-serve/server/config"
17 "github.com/go-git/go-git/v5/plumbing/format/pktline"
18)
19
20var testDaemon *GitDaemon
21
22func TestMain(m *testing.M) {
23 tmp, err := os.MkdirTemp("", "soft-serve-test")
24 if err != nil {
25 log.Fatal(err)
26 }
27 defer os.RemoveAll(tmp)
28 os.Setenv("SOFT_SERVE_DATA_PATH", tmp)
29 os.Setenv("SOFT_SERVE_GIT_MAX_CONNECTIONS", "3")
30 os.Setenv("SOFT_SERVE_GIT_MAX_TIMEOUT", "100")
31 os.Setenv("SOFT_SERVE_GIT_IDLE_TIMEOUT", "1")
32 os.Setenv("SOFT_SERVE_GIT_LISTEN_ADDR", fmt.Sprintf(":%d", randomPort()))
33 cfg := config.DefaultConfig()
34 d, err := NewGitDaemon(cfg)
35 if err != nil {
36 log.Fatal(err)
37 }
38 ctx := context.TODO()
39 fb, err := sqlite.NewSqliteBackend(ctx, cfg)
40 if err != nil {
41 log.Fatal(err)
42 }
43 cfg = cfg.WithBackend(fb)
44 testDaemon = d
45 go func() {
46 if err := d.Start(); err != ErrServerClosed {
47 log.Fatal(err)
48 }
49 }()
50 code := m.Run()
51 os.Unsetenv("SOFT_SERVE_DATA_PATH")
52 os.Unsetenv("SOFT_SERVE_GIT_MAX_CONNECTIONS")
53 os.Unsetenv("SOFT_SERVE_GIT_MAX_TIMEOUT")
54 os.Unsetenv("SOFT_SERVE_GIT_IDLE_TIMEOUT")
55 os.Unsetenv("SOFT_SERVE_GIT_LISTEN_ADDR")
56 _ = d.Close()
57 _ = fb.Close()
58 os.Exit(code)
59}
60
61func TestIdleTimeout(t *testing.T) {
62 c, err := net.Dial("tcp", testDaemon.addr)
63 if err != nil {
64 t.Fatal(err)
65 }
66 out, err := readPktline(c)
67 if err != nil && !errors.Is(err, io.EOF) {
68 t.Fatalf("expected nil, got error: %v", err)
69 }
70 if out != ErrTimeout.Error() || out == "" {
71 t.Fatalf("expected %q error, got %q", ErrTimeout, out)
72 }
73}
74
75func TestInvalidRepo(t *testing.T) {
76 c, err := net.Dial("tcp", testDaemon.addr)
77 if err != nil {
78 t.Fatal(err)
79 }
80 if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
81 t.Fatalf("expected nil, got error: %v", err)
82 }
83 out, err := readPktline(c)
84 if err != nil {
85 t.Fatalf("expected nil, got error: %v", err)
86 }
87 if out != ErrInvalidRepo.Error() {
88 t.Fatalf("expected %q error, got %q", ErrInvalidRepo, out)
89 }
90}
91
92func readPktline(c net.Conn) (string, error) {
93 buf, err := io.ReadAll(c)
94 if err != nil {
95 return "", err
96 }
97 pktout := pktline.NewScanner(bytes.NewReader(buf))
98 if !pktout.Scan() {
99 return "", pktout.Err()
100 }
101 return strings.TrimSpace(string(pktout.Bytes())), nil
102}