1package server
  2
import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"strings"
	"testing"
	"time"

	"github.com/charmbracelet/soft-serve/server/backend/sqlite"
	"github.com/charmbracelet/soft-serve/server/config"
	"github.com/go-git/go-git/v5/plumbing/format/pktline"
)
 18
 19var testDaemon *GitDaemon
 20
 21func TestMain(m *testing.M) {
 22	tmp, err := os.MkdirTemp("", "soft-serve-test")
 23	if err != nil {
 24		log.Fatal(err)
 25	}
 26	defer os.RemoveAll(tmp)
 27	os.Setenv("SOFT_SERVE_DATA_PATH", tmp)
 28	os.Setenv("SOFT_SERVE_GIT_MAX_CONNECTIONS", "3")
 29	os.Setenv("SOFT_SERVE_GIT_MAX_TIMEOUT", "100")
 30	os.Setenv("SOFT_SERVE_GIT_IDLE_TIMEOUT", "1")
 31	os.Setenv("SOFT_SERVE_GIT_LISTEN_ADDR", fmt.Sprintf(":%d", randomPort()))
 32	cfg := config.DefaultConfig()
 33	d, err := NewGitDaemon(cfg)
 34	if err != nil {
 35		log.Fatal(err)
 36	}
 37	fb, err := sqlite.NewSqliteBackend(cfg)
 38	if err != nil {
 39		log.Fatal(err)
 40	}
 41	cfg = cfg.WithBackend(fb)
 42	testDaemon = d
 43	go func() {
 44		if err := d.Start(); err != ErrServerClosed {
 45			log.Fatal(err)
 46		}
 47	}()
 48	code := m.Run()
 49	os.Unsetenv("SOFT_SERVE_DATA_PATH")
 50	os.Unsetenv("SOFT_SERVE_GIT_MAX_CONNECTIONS")
 51	os.Unsetenv("SOFT_SERVE_GIT_MAX_TIMEOUT")
 52	os.Unsetenv("SOFT_SERVE_GIT_IDLE_TIMEOUT")
 53	os.Unsetenv("SOFT_SERVE_GIT_LISTEN_ADDR")
 54	_ = d.Close()
 55	_ = fb.Close()
 56	os.Exit(code)
 57}
 58
 59func TestIdleTimeout(t *testing.T) {
 60	c, err := net.Dial("tcp", testDaemon.addr)
 61	if err != nil {
 62		t.Fatal(err)
 63	}
 64	out, err := readPktline(c)
 65	if err != nil && !errors.Is(err, io.EOF) {
 66		t.Fatalf("expected nil, got error: %v", err)
 67	}
 68	if out != ErrTimeout.Error() || out == "" {
 69		t.Fatalf("expected %q error, got %q", ErrTimeout, out)
 70	}
 71}
 72
 73func TestInvalidRepo(t *testing.T) {
 74	c, err := net.Dial("tcp", testDaemon.addr)
 75	if err != nil {
 76		t.Fatal(err)
 77	}
 78	if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
 79		t.Fatalf("expected nil, got error: %v", err)
 80	}
 81	out, err := readPktline(c)
 82	if err != nil {
 83		t.Fatalf("expected nil, got error: %v", err)
 84	}
 85	if out != ErrInvalidRepo.Error() {
 86		t.Fatalf("expected %q error, got %q", ErrInvalidRepo, out)
 87	}
 88}
 89
 90func readPktline(c net.Conn) (string, error) {
 91	buf, err := io.ReadAll(c)
 92	if err != nil {
 93		return "", err
 94	}
 95	pktout := pktline.NewScanner(bytes.NewReader(buf))
 96	if !pktout.Scan() {
 97		return "", pktout.Err()
 98	}
 99	return strings.TrimSpace(string(pktout.Bytes())), nil
100}