daemon_test.go

 1package daemon
 2
 3import (
 4	"bytes"
 5	"context"
 6	"io"
 7	"log"
 8	"net"
 9	"os"
10	"testing"
11
12	appCfg "github.com/charmbracelet/soft-serve/config"
13	"github.com/charmbracelet/soft-serve/server/config"
14	"github.com/charmbracelet/soft-serve/server/git"
15	"github.com/go-git/go-git/v5/plumbing/format/pktline"
16)
17
18var testDaemon *Daemon
19
20func TestMain(m *testing.M) {
21	cfg := config.DefaultConfig()
22	// Reduce the max connections to 3 so we can test the timeout.
23	cfg.GitMaxConnections = 3
24	// Reduce the max timeout to 100 second so we can test the timeout.
25	cfg.GitMaxTimeout = 100
26	// Reduce the max read timeout to 1 second so we can test the timeout.
27	cfg.GitMaxReadTimeout = 1
28	ac, err := appCfg.NewConfig(cfg)
29	if err != nil {
30		log.Fatal(err)
31	}
32	d, err := NewDaemon(cfg, ac)
33	if err != nil {
34		log.Fatal(err)
35	}
36	testDaemon = d
37	go func() {
38		if err := d.Start(); err != ErrServerClosed {
39			log.Fatal(err)
40		}
41	}()
42	defer d.Shutdown(context.Background())
43	os.Exit(m.Run())
44}
45
46func TestMaxReadTimeout(t *testing.T) {
47	c, err := net.Dial("tcp", testDaemon.addr)
48	if err != nil {
49		t.Fatal(err)
50	}
51	out, err := readPktline(c)
52	if err != nil {
53		t.Fatalf("expected nil, got error: %v", err)
54	}
55	if out != git.ErrMaxTimeout.Error() {
56		t.Fatalf("expected %q error, got nil", git.ErrMaxTimeout)
57	}
58}
59
60func TestInvalidRepo(t *testing.T) {
61	c, err := net.Dial("tcp", testDaemon.addr)
62	if err != nil {
63		t.Fatal(err)
64	}
65	if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
66		t.Fatalf("expected nil, got error: %v", err)
67	}
68	out, err := readPktline(c)
69	if err != nil {
70		t.Fatalf("expected nil, got error: %v", err)
71	}
72	if out != git.ErrInvalidRepo.Error() {
73		t.Fatalf("expected %q error, got nil", git.ErrInvalidRepo)
74	}
75}
76
77func readPktline(c net.Conn) (string, error) {
78	buf, err := io.ReadAll(c)
79	if err != nil {
80		return "", err
81	}
82	pktout := pktline.NewScanner(bytes.NewReader(buf))
83	if !pktout.Scan() {
84		return "", pktout.Err()
85	}
86	return string(pktout.Bytes()), nil
87}