daemon_test.go

 1package daemon
 2
import (
	"bytes"
	"errors"
	"io"
	"log"
	"net"
	"os"
	"testing"

	appCfg "github.com/charmbracelet/soft-serve/config"
	"github.com/charmbracelet/soft-serve/server/config"
	"github.com/charmbracelet/soft-serve/server/git"
	"github.com/go-git/go-git/v5/plumbing/format/pktline"
)
16
17var testDaemon *Daemon
18
// TestMain creates a git daemon backed by a throwaway "testdata" directory
// with deliberately reduced limits, starts it in the background, runs the
// package tests against it, and then tears it down.
//
// os.Exit does not run deferred calls, so teardown (closing the daemon and
// removing the testdata directory) is performed explicitly before exiting.
func TestMain(m *testing.M) {
	testdata := "testdata"
	removeTestdata := func() {
		// Best-effort cleanup; a failure here must not mask the test result.
		_ = os.RemoveAll(testdata)
	}
	cfg := &config.Config{
		Host:     "",
		DataPath: testdata,
		Git: config.GitConfig{
			// Reduce the max timeout to 100 seconds so we can test the timeout.
			MaxTimeout: 100,
			// Reduce the max read timeout to 1 second so we can test the timeout.
			MaxReadTimeout: 1,
			// Reduce the max connections to 3.
			MaxConnections: 3,
			Port:           9418,
		},
	}
	ac, err := appCfg.NewConfig(cfg)
	if err != nil {
		removeTestdata()
		log.Fatal(err)
	}
	d, err := NewDaemon(cfg, ac)
	if err != nil {
		removeTestdata()
		log.Fatal(err)
	}
	testDaemon = d
	go func() {
		// Any result from Start other than ErrServerClosed aborts the test
		// binary.
		if err := d.Start(); !errors.Is(err, ErrServerClosed) {
			log.Fatal(err)
		}
	}()
	code := m.Run()
	// Tear down explicitly: os.Exit below would skip deferred calls.
	d.Close()
	removeTestdata()
	os.Exit(code)
}
52
// TestMaxReadTimeout connects to the shared test daemon without sending any
// request and expects the first pkt-line the daemon sends back to carry the
// message of git.ErrMaxTimeout. It relies on the reduced timeouts configured
// in TestMain.
func TestMaxReadTimeout(t *testing.T) {
	c, err := net.Dial("tcp", testDaemon.addr)
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()
	out, err := readPktline(c)
	if err != nil {
		t.Fatalf("expected nil, got error: %v", err)
	}
	if out != git.ErrMaxTimeout.Error() {
		t.Fatalf("expected %q error, got %q", git.ErrMaxTimeout, out)
	}
}
66
// TestInvalidRepo sends a git-upload-pack request for "/test.git" to the
// shared test daemon and expects the first pkt-line in the response to carry
// the message of git.ErrInvalidRepo.
func TestInvalidRepo(t *testing.T) {
	c, err := net.Dial("tcp", testDaemon.addr)
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()
	if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
		t.Fatalf("expected nil, got error: %v", err)
	}
	out, err := readPktline(c)
	if err != nil {
		t.Fatalf("expected nil, got error: %v", err)
	}
	if out != git.ErrInvalidRepo.Error() {
		t.Fatalf("expected %q error, got %q", git.ErrInvalidRepo, out)
	}
}
83
// readPktline reads r until EOF (for a network connection, until the peer
// closes it) and returns the payload of the first pkt-line in that data as a
// string.
//
// It returns the read error if reading r fails, the scanner's error if the
// first pkt-line is malformed, and io.ErrUnexpectedEOF if r yielded no data,
// because the pkt-line scanner reports a clean end of input without an error.
func readPktline(r io.Reader) (string, error) {
	buf, err := io.ReadAll(r)
	if err != nil {
		return "", err
	}
	pktout := pktline.NewScanner(bytes.NewReader(buf))
	if !pktout.Scan() {
		if err := pktout.Err(); err != nil {
			return "", err
		}
		// Surface "no pkt-line at all" so callers do not mistake it for an
		// empty payload.
		return "", io.ErrUnexpectedEOF
	}
	return string(pktout.Bytes()), nil
}