daemon_test.go

  1package daemon
  2
  3import (
  4	"bytes"
  5	"errors"
  6	"io"
  7	"log"
  8	"net"
  9	"os"
 10	"strconv"
 11	"testing"
 12	"time"
 13
 14	"github.com/charmbracelet/soft-serve/server/config"
 15	"github.com/charmbracelet/soft-serve/server/git"
 16	"github.com/go-git/go-git/v5/plumbing/format/pktline"
 17)
 18
 19var testDaemon *Daemon
 20
 21func TestMain(m *testing.M) {
 22	tmp, err := os.MkdirTemp("", "soft-serve-test")
 23	if err != nil {
 24		log.Fatal(err)
 25	}
 26	defer os.RemoveAll(tmp)
 27	os.Setenv("SOFT_SERVE_DATA_PATH", tmp)
 28	os.Setenv("SOFT_SERVE_ANON_ACCESS", "read-only")
 29	os.Setenv("SOFT_SERVE_GIT_MAX_CONNECTIONS", "3")
 30	os.Setenv("SOFT_SERVE_GIT_MAX_TIMEOUT", "100")
 31	os.Setenv("SOFT_SERVE_GIT_IDLE_TIMEOUT", "1")
 32	os.Setenv("SOFT_SERVE_GIT_PORT", strconv.Itoa(randomPort()))
 33	cfg := config.DefaultConfig()
 34	d, err := NewDaemon(cfg)
 35	if err != nil {
 36		log.Fatal(err)
 37	}
 38	testDaemon = d
 39	go func() {
 40		if err := d.Start(); err != ErrServerClosed {
 41			log.Fatal(err)
 42		}
 43	}()
 44	code := m.Run()
 45	os.Unsetenv("SOFT_SERVE_DATA_PATH")
 46	os.Unsetenv("SOFT_SERVE_ANON_ACCESS")
 47	os.Unsetenv("SOFT_SERVE_GIT_MAX_CONNECTIONS")
 48	os.Unsetenv("SOFT_SERVE_GIT_MAX_TIMEOUT")
 49	os.Unsetenv("SOFT_SERVE_GIT_IDLE_TIMEOUT")
 50	os.Unsetenv("SOFT_SERVE_GIT_PORT")
 51	_ = d.Close()
 52	os.Exit(code)
 53}
 54
 55func TestIdleTimeout(t *testing.T) {
 56	c, err := net.Dial("tcp", testDaemon.addr)
 57	if err != nil {
 58		t.Fatal(err)
 59	}
 60	time.Sleep(2 * time.Second)
 61	out, err := readPktline(c)
 62	if err != nil && !errors.Is(err, io.EOF) {
 63		t.Fatalf("expected nil, got error: %v", err)
 64	}
 65	if out != git.ErrTimeout.Error() {
 66		t.Fatalf("expected %q error, got %q", git.ErrTimeout, out)
 67	}
 68}
 69
 70func TestInvalidRepo(t *testing.T) {
 71	c, err := net.Dial("tcp", testDaemon.addr)
 72	if err != nil {
 73		t.Fatal(err)
 74	}
 75	if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
 76		t.Fatalf("expected nil, got error: %v", err)
 77	}
 78	out, err := readPktline(c)
 79	if err != nil {
 80		t.Fatalf("expected nil, got error: %v", err)
 81	}
 82	if out != git.ErrInvalidRepo.Error() {
 83		t.Fatalf("expected %q error, got %q", git.ErrInvalidRepo, out)
 84	}
 85}
 86
 87func readPktline(c net.Conn) (string, error) {
 88	buf, err := io.ReadAll(c)
 89	if err != nil {
 90		return "", err
 91	}
 92	pktout := pktline.NewScanner(bytes.NewReader(buf))
 93	if !pktout.Scan() {
 94		return "", pktout.Err()
 95	}
 96	return string(pktout.Bytes()), nil
 97}
 98
 99func randomPort() int {
100	addr, _ := net.Listen("tcp", ":0") //nolint:gosec
101	_ = addr.Close()
102	return addr.Addr().(*net.TCPAddr).Port
103}