daemon_test.go

 1package daemon
 2
 3import (
 4	"bytes"
 5	"io"
 6	"log"
 7	"net"
 8	"os"
 9	"testing"
10
11	"github.com/charmbracelet/soft-serve/server/config"
12	"github.com/charmbracelet/soft-serve/server/git"
13	"github.com/go-git/go-git/v5/plumbing/format/pktline"
14)
15
// testDaemon is the daemon instance shared by the tests in this file. It is
// assigned in TestMain, and the tests connect to it through testDaemon.addr.
16var testDaemon *Daemon
17
18func TestMain(m *testing.M) {
19	tmp, err := os.MkdirTemp("", "soft-serve-test")
20	if err != nil {
21		log.Fatal(err)
22	}
23	defer os.RemoveAll(tmp)
24	cfg := &config.Config{
25		Host:     "",
26		DataPath: tmp,
27		Git: config.GitConfig{
28			// Reduce the max timeout to 100 second so we can test the timeout.
29			MaxTimeout: 100,
30			// Reduce the max read timeout to 1 second so we can test the timeout.
31			MaxReadTimeout: 1,
32			// Reduce the max connections to 3 so we can test the timeout.
33			MaxConnections: 3,
34			Port:           9418,
35		},
36	}
37	d, err := NewDaemon(cfg)
38	if err != nil {
39		log.Fatal(err)
40	}
41	testDaemon = d
42	go func() {
43		if err := d.Start(); err != ErrServerClosed {
44			log.Fatal(err)
45		}
46	}()
47	defer d.Close()
48	os.Exit(m.Run())
49}
50
51func TestMaxReadTimeout(t *testing.T) {
52	c, err := net.Dial("tcp", testDaemon.addr)
53	if err != nil {
54		t.Fatal(err)
55	}
56	out, err := readPktline(c)
57	if err != nil {
58		t.Fatalf("expected nil, got error: %v", err)
59	}
60	if out != git.ErrMaxTimeout.Error() {
61		t.Fatalf("expected %q error, got nil", git.ErrMaxTimeout)
62	}
63}
64
65func TestInvalidRepo(t *testing.T) {
66	c, err := net.Dial("tcp", testDaemon.addr)
67	if err != nil {
68		t.Fatal(err)
69	}
70	if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
71		t.Fatalf("expected nil, got error: %v", err)
72	}
73	out, err := readPktline(c)
74	if err != nil {
75		t.Fatalf("expected nil, got error: %v", err)
76	}
77	if out != git.ErrInvalidRepo.Error() {
78		t.Fatalf("expected %q error, got nil", git.ErrInvalidRepo)
79	}
80}
81
82func readPktline(c net.Conn) (string, error) {
83	buf, err := io.ReadAll(c)
84	if err != nil {
85		return "", err
86	}
87	pktout := pktline.NewScanner(bytes.NewReader(buf))
88	if !pktout.Scan() {
89		return "", pktout.Err()
90	}
91	return string(pktout.Bytes()), nil
92}