1package server
  2
  3import (
  4	"bytes"
  5	"errors"
  6	"fmt"
  7	"io"
  8	"log"
  9	"net"
 10	"os"
 11	"path/filepath"
 12	"strings"
 13	"testing"
 14	"time"
 15
 16	"github.com/charmbracelet/soft-serve/server/backend/file"
 17	"github.com/charmbracelet/soft-serve/server/config"
 18	"github.com/go-git/go-git/v5/plumbing/format/pktline"
 19)
 20
// testDaemon is the GitDaemon created and started by TestMain. Tests dial its
// addr field to talk to the running daemon.
var testDaemon *GitDaemon
 22
 23func TestMain(m *testing.M) {
 24	tmp, err := os.MkdirTemp("", "soft-serve-test")
 25	if err != nil {
 26		log.Fatal(err)
 27	}
 28	defer os.RemoveAll(tmp)
 29	os.Setenv("SOFT_SERVE_DATA_PATH", tmp)
 30	os.Setenv("SOFT_SERVE_GIT_MAX_CONNECTIONS", "3")
 31	os.Setenv("SOFT_SERVE_GIT_MAX_TIMEOUT", "100")
 32	os.Setenv("SOFT_SERVE_GIT_IDLE_TIMEOUT", "1")
 33	os.Setenv("SOFT_SERVE_GIT_LISTEN_ADDR", fmt.Sprintf(":%d", randomPort()))
 34	fb, err := file.NewFileBackend(filepath.Join(tmp, "repos"))
 35	if err != nil {
 36		log.Fatal(err)
 37	}
 38	cfg := config.DefaultConfig().WithBackend(fb).WithAccessMethod(fb)
 39	d, err := NewGitDaemon(cfg)
 40	if err != nil {
 41		log.Fatal(err)
 42	}
 43	testDaemon = d
 44	go func() {
 45		if err := d.Start(); err != ErrServerClosed {
 46			log.Fatal(err)
 47		}
 48	}()
 49	code := m.Run()
 50	os.Unsetenv("SOFT_SERVE_DATA_PATH")
 51	os.Unsetenv("SOFT_SERVE_GIT_MAX_CONNECTIONS")
 52	os.Unsetenv("SOFT_SERVE_GIT_MAX_TIMEOUT")
 53	os.Unsetenv("SOFT_SERVE_GIT_IDLE_TIMEOUT")
 54	os.Unsetenv("SOFT_SERVE_GIT_LISTEN_ADDR")
 55	_ = d.Close()
 56	os.Exit(code)
 57}
 58
 59func TestIdleTimeout(t *testing.T) {
 60	c, err := net.Dial("tcp", testDaemon.addr)
 61	if err != nil {
 62		t.Fatal(err)
 63	}
 64	time.Sleep(2 * time.Second)
 65	out, err := readPktline(c)
 66	if err != nil && !errors.Is(err, io.EOF) {
 67		t.Fatalf("expected nil, got error: %v", err)
 68	}
 69	if out != ErrTimeout.Error() {
 70		t.Fatalf("expected %q error, got %q", ErrTimeout, out)
 71	}
 72}
 73
 74func TestInvalidRepo(t *testing.T) {
 75	c, err := net.Dial("tcp", testDaemon.addr)
 76	if err != nil {
 77		t.Fatal(err)
 78	}
 79	if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
 80		t.Fatalf("expected nil, got error: %v", err)
 81	}
 82	out, err := readPktline(c)
 83	if err != nil {
 84		t.Fatalf("expected nil, got error: %v", err)
 85	}
 86	if out != ErrInvalidRepo.Error() {
 87		t.Fatalf("expected %q error, got %q", ErrInvalidRepo, out)
 88	}
 89}
 90
 91func readPktline(c net.Conn) (string, error) {
 92	buf, err := io.ReadAll(c)
 93	if err != nil {
 94		return "", err
 95	}
 96	pktout := pktline.NewScanner(bytes.NewReader(buf))
 97	if !pktout.Scan() {
 98		return "", pktout.Err()
 99	}
100	return strings.TrimSpace(string(pktout.Bytes())), nil
101}