1package server
  2
  3import (
  4	"bytes"
  5	"context"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"log"
 10	"net"
 11	"os"
 12	"strings"
 13	"testing"
 14
 15	"github.com/charmbracelet/soft-serve/server/backend/sqlite"
 16	"github.com/charmbracelet/soft-serve/server/config"
 17	"github.com/go-git/go-git/v5/plumbing/format/pktline"
 18)
 19
 20var testDaemon *GitDaemon
 21
 22func TestMain(m *testing.M) {
 23	tmp, err := os.MkdirTemp("", "soft-serve-test")
 24	if err != nil {
 25		log.Fatal(err)
 26	}
 27	defer os.RemoveAll(tmp)
 28	os.Setenv("SOFT_SERVE_DATA_PATH", tmp)
 29	os.Setenv("SOFT_SERVE_GIT_MAX_CONNECTIONS", "3")
 30	os.Setenv("SOFT_SERVE_GIT_MAX_TIMEOUT", "100")
 31	os.Setenv("SOFT_SERVE_GIT_IDLE_TIMEOUT", "1")
 32	os.Setenv("SOFT_SERVE_GIT_LISTEN_ADDR", fmt.Sprintf(":%d", randomPort()))
 33	cfg := config.DefaultConfig()
 34	d, err := NewGitDaemon(cfg)
 35	if err != nil {
 36		log.Fatal(err)
 37	}
 38	ctx := context.TODO()
 39	fb, err := sqlite.NewSqliteBackend(ctx, cfg)
 40	if err != nil {
 41		log.Fatal(err)
 42	}
 43	cfg = cfg.WithBackend(fb)
 44	testDaemon = d
 45	go func() {
 46		if err := d.Start(); err != ErrServerClosed {
 47			log.Fatal(err)
 48		}
 49	}()
 50	code := m.Run()
 51	os.Unsetenv("SOFT_SERVE_DATA_PATH")
 52	os.Unsetenv("SOFT_SERVE_GIT_MAX_CONNECTIONS")
 53	os.Unsetenv("SOFT_SERVE_GIT_MAX_TIMEOUT")
 54	os.Unsetenv("SOFT_SERVE_GIT_IDLE_TIMEOUT")
 55	os.Unsetenv("SOFT_SERVE_GIT_LISTEN_ADDR")
 56	_ = d.Close()
 57	_ = fb.Close()
 58	os.Exit(code)
 59}
 60
 61func TestIdleTimeout(t *testing.T) {
 62	c, err := net.Dial("tcp", testDaemon.addr)
 63	if err != nil {
 64		t.Fatal(err)
 65	}
 66	out, err := readPktline(c)
 67	if err != nil && !errors.Is(err, io.EOF) {
 68		t.Fatalf("expected nil, got error: %v", err)
 69	}
 70	if out != ErrTimeout.Error() || out == "" {
 71		t.Fatalf("expected %q error, got %q", ErrTimeout, out)
 72	}
 73}
 74
 75func TestInvalidRepo(t *testing.T) {
 76	c, err := net.Dial("tcp", testDaemon.addr)
 77	if err != nil {
 78		t.Fatal(err)
 79	}
 80	if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
 81		t.Fatalf("expected nil, got error: %v", err)
 82	}
 83	out, err := readPktline(c)
 84	if err != nil {
 85		t.Fatalf("expected nil, got error: %v", err)
 86	}
 87	if out != ErrInvalidRepo.Error() {
 88		t.Fatalf("expected %q error, got %q", ErrInvalidRepo, out)
 89	}
 90}
 91
 92func readPktline(c net.Conn) (string, error) {
 93	buf, err := io.ReadAll(c)
 94	if err != nil {
 95		return "", err
 96	}
 97	pktout := pktline.NewScanner(bytes.NewReader(buf))
 98	if !pktout.Scan() {
 99		return "", pktout.Err()
100	}
101	return strings.TrimSpace(string(pktout.Bytes())), nil
102}