daemon_test.go

 1package daemon
 2
 3import (
 4	"bytes"
 5	"io"
 6	"log"
 7	"net"
 8	"os"
 9	"testing"
10
11	"github.com/charmbracelet/soft-serve/proto"
12	"github.com/charmbracelet/soft-serve/server/config"
13	"github.com/charmbracelet/soft-serve/server/db/fakedb"
14	"github.com/charmbracelet/soft-serve/server/git"
15	"github.com/go-git/go-git/v5/plumbing/format/pktline"
16)
17
18var testDaemon *Daemon
19
20func TestMain(m *testing.M) {
21	tmp, err := os.MkdirTemp("", "soft-serve-test")
22	if err != nil {
23		log.Fatal(err)
24	}
25	defer os.RemoveAll(tmp)
26	cfg := &config.Config{
27		Host:       "",
28		DataPath:   tmp,
29		AnonAccess: proto.ReadOnlyAccess,
30		Git: config.GitConfig{
31			// Reduce the max read timeout to 1 second so we can test the timeout.
32			IdleTimeout: 3,
33			// Reduce the max timeout to 100 second so we can test the timeout.
34			MaxTimeout: 100,
35			// Reduce the max connections to 3 so we can test the timeout.
36			MaxConnections: 3,
37			Port:           9418,
38		},
39	}
40	cfg = cfg.WithDB(&fakedb.FakeDB{})
41	d, err := NewDaemon(cfg)
42	if err != nil {
43		log.Fatal(err)
44	}
45	testDaemon = d
46	go func() {
47		if err := d.Start(); err != ErrServerClosed {
48			log.Fatal(err)
49		}
50	}()
51	defer d.Close()
52	os.Exit(m.Run())
53}
54
55func TestIdleTimeout(t *testing.T) {
56	c, err := net.Dial("tcp", testDaemon.addr)
57	if err != nil {
58		t.Fatal(err)
59	}
60	out, err := readPktline(c)
61	if err != nil {
62		t.Fatalf("expected nil, got error: %v", err)
63	}
64	if out != git.ErrTimeout.Error() {
65		t.Fatalf("expected %q error, got %q", git.ErrTimeout, out)
66	}
67}
68
69func TestInvalidRepo(t *testing.T) {
70	c, err := net.Dial("tcp", testDaemon.addr)
71	if err != nil {
72		t.Fatal(err)
73	}
74	if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
75		t.Fatalf("expected nil, got error: %v", err)
76	}
77	out, err := readPktline(c)
78	if err != nil {
79		t.Fatalf("expected nil, got error: %v", err)
80	}
81	if out != git.ErrInvalidRepo.Error() {
82		t.Fatalf("expected %q error, got %q", git.ErrInvalidRepo, out)
83	}
84}
85
86func readPktline(c net.Conn) (string, error) {
87	buf, err := io.ReadAll(c)
88	if err != nil {
89		return "", err
90	}
91	pktout := pktline.NewScanner(bytes.NewReader(buf))
92	if !pktout.Scan() {
93		return "", pktout.Err()
94	}
95	return string(pktout.Bytes()), nil
96}