daemon_test.go

 1package daemon
 2
 3import (
 4	"bytes"
 5	"io"
 6	"log"
 7	"net"
 8	"os"
 9	"strconv"
10	"testing"
11
12	"github.com/charmbracelet/soft-serve/server/config"
13	"github.com/charmbracelet/soft-serve/server/git"
14	"github.com/go-git/go-git/v5/plumbing/format/pktline"
15)
16
17var testDaemon *Daemon
18
19func TestMain(m *testing.M) {
20	tmp, err := os.MkdirTemp("", "soft-serve-test")
21	if err != nil {
22		log.Fatal(err)
23	}
24	defer os.RemoveAll(tmp)
25	os.Setenv("SOFT_SERVE_DATA_PATH", tmp)
26	os.Setenv("SOFT_SERVE_ANON_ACCESS", "read-only")
27	os.Setenv("SOFT_SERVE_GIT_MAX_CONNECTIONS", "3")
28	os.Setenv("SOFT_SERVE_GIT_MAX_TIMEOUT", "100")
29	os.Setenv("SOFT_SERVE_GIT_IDLE_TIMEOUT", "3")
30	os.Setenv("SOFT_SERVE_GIT_PORT", strconv.Itoa(randomPort()))
31	cfg := config.DefaultConfig()
32	d, err := NewDaemon(cfg)
33	if err != nil {
34		log.Fatal(err)
35	}
36	testDaemon = d
37	go func() {
38		if err := d.Start(); err != ErrServerClosed {
39			log.Fatal(err)
40		}
41	}()
42	defer d.Close()
43	os.Exit(m.Run())
44	os.Unsetenv("SOFT_SERVE_DATA_PATH")
45	os.Unsetenv("SOFT_SERVE_ANON_ACCESS")
46	os.Unsetenv("SOFT_SERVE_GIT_MAX_CONNECTIONS")
47	os.Unsetenv("SOFT_SERVE_GIT_MAX_TIMEOUT")
48	os.Unsetenv("SOFT_SERVE_GIT_IDLE_TIMEOUT")
49	os.Unsetenv("SOFT_SERVE_GIT_PORT")
50}
51
52func TestIdleTimeout(t *testing.T) {
53	c, err := net.Dial("tcp", testDaemon.addr)
54	if err != nil {
55		t.Fatal(err)
56	}
57	out, err := readPktline(c)
58	if err != nil {
59		t.Fatalf("expected nil, got error: %v", err)
60	}
61	if out != git.ErrTimeout.Error() {
62		t.Fatalf("expected %q error, got %q", git.ErrTimeout, out)
63	}
64}
65
66func TestInvalidRepo(t *testing.T) {
67	c, err := net.Dial("tcp", testDaemon.addr)
68	if err != nil {
69		t.Fatal(err)
70	}
71	if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
72		t.Fatalf("expected nil, got error: %v", err)
73	}
74	out, err := readPktline(c)
75	if err != nil {
76		t.Fatalf("expected nil, got error: %v", err)
77	}
78	if out != git.ErrInvalidRepo.Error() {
79		t.Fatalf("expected %q error, got %q", git.ErrInvalidRepo, out)
80	}
81}
82
83func readPktline(c net.Conn) (string, error) {
84	buf, err := io.ReadAll(c)
85	if err != nil {
86		return "", err
87	}
88	pktout := pktline.NewScanner(bytes.NewReader(buf))
89	if !pktout.Scan() {
90		return "", pktout.Err()
91	}
92	return string(pktout.Bytes()), nil
93}
94
// randomPort returns a TCP port the OS reported as free. It finds one by
// listening on ":0" and immediately closing the listener. The port is not
// reserved, so another process could bind it before the caller does.
//
// It panics (via log.Panicf) if no listener can be opened. A panic, unlike
// log.Fatal, still runs the caller's deferred cleanup.
func randomPort() int {
	ln, err := net.Listen("tcp", ":0") //nolint:gosec
	if err != nil {
		log.Panicf("randomPort: %v", err)
	}
	defer ln.Close()
	return ln.Addr().(*net.TCPAddr).Port
}