1package daemon
  2
  3import (
  4	"bytes"
  5	"context"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"log"
 10	"net"
 11	"os"
 12	"strings"
 13	"testing"
 14
 15	"github.com/charmbracelet/soft-serve/server/backend"
 16	"github.com/charmbracelet/soft-serve/server/config"
 17	"github.com/charmbracelet/soft-serve/server/db"
 18	"github.com/charmbracelet/soft-serve/server/db/migrate"
 19	"github.com/charmbracelet/soft-serve/server/git"
 20	"github.com/charmbracelet/soft-serve/server/store"
 21	"github.com/charmbracelet/soft-serve/server/store/database"
 22	"github.com/charmbracelet/soft-serve/server/test"
 23	"github.com/go-git/go-git/v5/plumbing/format/pktline"
 24	_ "modernc.org/sqlite" // sqlite driver
 25)
 26
 27var testDaemon *GitDaemon
 28
 29func TestMain(m *testing.M) {
 30	tmp, err := os.MkdirTemp("", "soft-serve-test")
 31	if err != nil {
 32		log.Fatal(err)
 33	}
 34	defer os.RemoveAll(tmp)
 35	ctx := context.TODO()
 36	cfg := config.DefaultConfig()
 37	cfg.DataPath = tmp
 38	cfg.Git.MaxConnections = 3
 39	cfg.Git.MaxTimeout = 100
 40	cfg.Git.IdleTimeout = 1
 41	cfg.Git.ListenAddr = fmt.Sprintf(":%d", test.RandomPort())
 42	if err := cfg.Validate(); err != nil {
 43		log.Fatal(err)
 44	}
 45	ctx = config.WithContext(ctx, cfg)
 46	dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
 47	if err != nil {
 48		log.Fatal(err)
 49	}
 50	defer dbx.Close() // nolint: errcheck
 51	if err := migrate.Migrate(ctx, dbx); err != nil {
 52		log.Fatal(err)
 53	}
 54	datastore := database.New(ctx, dbx)
 55	ctx = store.WithContext(ctx, datastore)
 56	be := backend.New(ctx, cfg, dbx)
 57	ctx = backend.WithContext(ctx, be)
 58	d, err := NewGitDaemon(ctx)
 59	if err != nil {
 60		log.Fatal(err)
 61	}
 62	testDaemon = d
 63	go func() {
 64		if err := d.Start(); err != ErrServerClosed {
 65			log.Fatal(err)
 66		}
 67	}()
 68	code := m.Run()
 69	os.Unsetenv("SOFT_SERVE_DATA_PATH")
 70	os.Unsetenv("SOFT_SERVE_GIT_MAX_CONNECTIONS")
 71	os.Unsetenv("SOFT_SERVE_GIT_MAX_TIMEOUT")
 72	os.Unsetenv("SOFT_SERVE_GIT_IDLE_TIMEOUT")
 73	os.Unsetenv("SOFT_SERVE_GIT_LISTEN_ADDR")
 74	_ = d.Close()
 75	_ = dbx.Close()
 76	os.Exit(code)
 77}
 78
 79func TestIdleTimeout(t *testing.T) {
 80	c, err := net.Dial("tcp", testDaemon.addr)
 81	if err != nil {
 82		t.Fatal(err)
 83	}
 84	out, err := readPktline(c)
 85	if err != nil && !errors.Is(err, io.EOF) {
 86		t.Fatalf("expected nil, got error: %v", err)
 87	}
 88	if out != "ERR "+git.ErrTimeout.Error() && out != "" {
 89		t.Fatalf("expected %q error, got %q", git.ErrTimeout, out)
 90	}
 91}
 92
 93func TestInvalidRepo(t *testing.T) {
 94	c, err := net.Dial("tcp", testDaemon.addr)
 95	if err != nil {
 96		t.Fatal(err)
 97	}
 98	if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
 99		t.Fatalf("expected nil, got error: %v", err)
100	}
101	out, err := readPktline(c)
102	if err != nil {
103		t.Fatalf("expected nil, got error: %v", err)
104	}
105	if out != "ERR "+git.ErrInvalidRepo.Error() {
106		t.Fatalf("expected %q error, got %q", git.ErrInvalidRepo, out)
107	}
108}
109
110func readPktline(c net.Conn) (string, error) {
111	buf, err := io.ReadAll(c)
112	if err != nil {
113		return "", err
114	}
115	pktout := pktline.NewScanner(bytes.NewReader(buf))
116	if !pktout.Scan() {
117		return "", pktout.Err()
118	}
119	return strings.TrimSpace(string(pktout.Bytes())), nil
120}