daemon_test.go

 1package daemon
 2
 3import (
 4	"bytes"
 5	"io"
 6	"log"
 7	"net"
 8	"os"
 9	"testing"
10
11	appCfg "github.com/charmbracelet/soft-serve/config"
12	"github.com/charmbracelet/soft-serve/server/config"
13	"github.com/charmbracelet/soft-serve/server/git"
14	"github.com/go-git/go-git/v5/plumbing/format/pktline"
15)
16
17var testDaemon *Daemon
18
19func TestMain(m *testing.M) {
20	tmp, err := os.MkdirTemp("", "soft-serve-test")
21	if err != nil {
22		log.Fatal(err)
23	}
24	defer os.RemoveAll(tmp)
25	cfg := &config.Config{
26		Host:     "",
27		DataPath: tmp,
28		Git: config.GitConfig{
29			// Reduce the max timeout to 100 second so we can test the timeout.
30			MaxTimeout: 100,
31			// Reduce the max read timeout to 1 second so we can test the timeout.
32			MaxReadTimeout: 1,
33			// Reduce the max connections to 3 so we can test the timeout.
34			MaxConnections: 3,
35			Port:           9418,
36		},
37	}
38	ac, err := appCfg.NewConfig(cfg)
39	if err != nil {
40		log.Fatal(err)
41	}
42	d, err := NewDaemon(cfg, ac)
43	if err != nil {
44		log.Fatal(err)
45	}
46	testDaemon = d
47	go func() {
48		if err := d.Start(); err != ErrServerClosed {
49			log.Fatal(err)
50		}
51	}()
52	defer d.Close()
53	os.Exit(m.Run())
54}
55
56func TestMaxReadTimeout(t *testing.T) {
57	c, err := net.Dial("tcp", testDaemon.addr)
58	if err != nil {
59		t.Fatal(err)
60	}
61	out, err := readPktline(c)
62	if err != nil {
63		t.Fatalf("expected nil, got error: %v", err)
64	}
65	if out != git.ErrMaxTimeout.Error() {
66		t.Fatalf("expected %q error, got nil", git.ErrMaxTimeout)
67	}
68}
69
70func TestInvalidRepo(t *testing.T) {
71	c, err := net.Dial("tcp", testDaemon.addr)
72	if err != nil {
73		t.Fatal(err)
74	}
75	if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
76		t.Fatalf("expected nil, got error: %v", err)
77	}
78	out, err := readPktline(c)
79	if err != nil {
80		t.Fatalf("expected nil, got error: %v", err)
81	}
82	if out != git.ErrInvalidRepo.Error() {
83		t.Fatalf("expected %q error, got nil", git.ErrInvalidRepo)
84	}
85}
86
87func readPktline(c net.Conn) (string, error) {
88	buf, err := io.ReadAll(c)
89	if err != nil {
90		return "", err
91	}
92	pktout := pktline.NewScanner(bytes.NewReader(buf))
93	if !pktout.Scan() {
94		return "", pktout.Err()
95	}
96	return string(pktout.Bytes()), nil
97}