1package server
2
3import (
4 "bytes"
5 "errors"
6 "fmt"
7 "io"
8 "log"
9 "net"
10 "os"
11 "path/filepath"
12 "strings"
13 "testing"
14 "time"
15
16 "github.com/charmbracelet/soft-serve/server/backend/file"
17 "github.com/charmbracelet/soft-serve/server/config"
18 "github.com/go-git/go-git/v5/plumbing/format/pktline"
19)
20
21var testDaemon *GitDaemon
22
23func TestMain(m *testing.M) {
24 tmp, err := os.MkdirTemp("", "soft-serve-test")
25 if err != nil {
26 log.Fatal(err)
27 }
28 defer os.RemoveAll(tmp)
29 os.Setenv("SOFT_SERVE_DATA_PATH", tmp)
30 os.Setenv("SOFT_SERVE_GIT_MAX_CONNECTIONS", "3")
31 os.Setenv("SOFT_SERVE_GIT_MAX_TIMEOUT", "100")
32 os.Setenv("SOFT_SERVE_GIT_IDLE_TIMEOUT", "1")
33 os.Setenv("SOFT_SERVE_GIT_LISTEN_ADDR", fmt.Sprintf(":%d", randomPort()))
34 fb, err := file.NewFileBackend(filepath.Join(tmp, "repos"))
35 if err != nil {
36 log.Fatal(err)
37 }
38 cfg := config.DefaultConfig().WithBackend(fb).WithAccessMethod(fb)
39 d, err := NewGitDaemon(cfg)
40 if err != nil {
41 log.Fatal(err)
42 }
43 testDaemon = d
44 go func() {
45 if err := d.Start(); err != ErrServerClosed {
46 log.Fatal(err)
47 }
48 }()
49 code := m.Run()
50 os.Unsetenv("SOFT_SERVE_DATA_PATH")
51 os.Unsetenv("SOFT_SERVE_GIT_MAX_CONNECTIONS")
52 os.Unsetenv("SOFT_SERVE_GIT_MAX_TIMEOUT")
53 os.Unsetenv("SOFT_SERVE_GIT_IDLE_TIMEOUT")
54 os.Unsetenv("SOFT_SERVE_GIT_LISTEN_ADDR")
55 _ = d.Close()
56 os.Exit(code)
57}
58
59func TestIdleTimeout(t *testing.T) {
60 c, err := net.Dial("tcp", testDaemon.addr)
61 if err != nil {
62 t.Fatal(err)
63 }
64 time.Sleep(2 * time.Second)
65 out, err := readPktline(c)
66 if err != nil && !errors.Is(err, io.EOF) {
67 t.Fatalf("expected nil, got error: %v", err)
68 }
69 if out != ErrTimeout.Error() {
70 t.Fatalf("expected %q error, got %q", ErrTimeout, out)
71 }
72}
73
74func TestInvalidRepo(t *testing.T) {
75 c, err := net.Dial("tcp", testDaemon.addr)
76 if err != nil {
77 t.Fatal(err)
78 }
79 if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
80 t.Fatalf("expected nil, got error: %v", err)
81 }
82 out, err := readPktline(c)
83 if err != nil {
84 t.Fatalf("expected nil, got error: %v", err)
85 }
86 if out != ErrInvalidRepo.Error() {
87 t.Fatalf("expected %q error, got %q", ErrInvalidRepo, out)
88 }
89}
90
91func readPktline(c net.Conn) (string, error) {
92 buf, err := io.ReadAll(c)
93 if err != nil {
94 return "", err
95 }
96 pktout := pktline.NewScanner(bytes.NewReader(buf))
97 if !pktout.Scan() {
98 return "", pktout.Err()
99 }
100 return strings.TrimSpace(string(pktout.Bytes())), nil
101}