1package daemon
  2
  3import (
  4	"bytes"
  5	"context"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"log"
 10	"net"
 11	"os"
 12	"strings"
 13	"testing"
 14
 15	"github.com/charmbracelet/soft-serve/server/backend"
 16	"github.com/charmbracelet/soft-serve/server/config"
 17	"github.com/charmbracelet/soft-serve/server/db"
 18	"github.com/charmbracelet/soft-serve/server/db/migrate"
 19	"github.com/charmbracelet/soft-serve/server/git"
 20	"github.com/charmbracelet/soft-serve/server/test"
 21	"github.com/go-git/go-git/v5/plumbing/format/pktline"
 22	_ "modernc.org/sqlite" // sqlite driver
 23)
 24
 25var testDaemon *GitDaemon
 26
 27func TestMain(m *testing.M) {
 28	tmp, err := os.MkdirTemp("", "soft-serve-test")
 29	if err != nil {
 30		log.Fatal(err)
 31	}
 32	defer os.RemoveAll(tmp)
 33	os.Setenv("SOFT_SERVE_DATA_PATH", tmp)
 34	os.Setenv("SOFT_SERVE_GIT_MAX_CONNECTIONS", "3")
 35	os.Setenv("SOFT_SERVE_GIT_MAX_TIMEOUT", "100")
 36	os.Setenv("SOFT_SERVE_GIT_IDLE_TIMEOUT", "1")
 37	os.Setenv("SOFT_SERVE_GIT_LISTEN_ADDR", fmt.Sprintf(":%d", test.RandomPort()))
 38	ctx := context.TODO()
 39	cfg := config.DefaultConfig()
 40	if err := cfg.Validate(); err != nil {
 41		log.Fatal(err)
 42	}
 43	ctx = config.WithContext(ctx, cfg)
 44	db, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
 45	if err != nil {
 46		log.Fatal(err)
 47	}
 48	defer db.Close() // nolint: errcheck
 49	if err := migrate.Migrate(ctx, db); err != nil {
 50		log.Fatal(err)
 51	}
 52	be := backend.New(ctx, cfg, db)
 53	ctx = backend.WithContext(ctx, be)
 54	d, err := NewGitDaemon(ctx)
 55	if err != nil {
 56		log.Fatal(err)
 57	}
 58	testDaemon = d
 59	go func() {
 60		if err := d.Start(); err != ErrServerClosed {
 61			log.Fatal(err)
 62		}
 63	}()
 64	code := m.Run()
 65	os.Unsetenv("SOFT_SERVE_DATA_PATH")
 66	os.Unsetenv("SOFT_SERVE_GIT_MAX_CONNECTIONS")
 67	os.Unsetenv("SOFT_SERVE_GIT_MAX_TIMEOUT")
 68	os.Unsetenv("SOFT_SERVE_GIT_IDLE_TIMEOUT")
 69	os.Unsetenv("SOFT_SERVE_GIT_LISTEN_ADDR")
 70	_ = d.Close()
 71	_ = db.Close()
 72	os.Exit(code)
 73}
 74
 75func TestIdleTimeout(t *testing.T) {
 76	c, err := net.Dial("tcp", testDaemon.addr)
 77	if err != nil {
 78		t.Fatal(err)
 79	}
 80	out, err := readPktline(c)
 81	if err != nil && !errors.Is(err, io.EOF) {
 82		t.Fatalf("expected nil, got error: %v", err)
 83	}
 84	if out != "ERR "+git.ErrTimeout.Error() && out != "" {
 85		t.Fatalf("expected %q error, got %q", git.ErrTimeout, out)
 86	}
 87}
 88
 89func TestInvalidRepo(t *testing.T) {
 90	c, err := net.Dial("tcp", testDaemon.addr)
 91	if err != nil {
 92		t.Fatal(err)
 93	}
 94	if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
 95		t.Fatalf("expected nil, got error: %v", err)
 96	}
 97	out, err := readPktline(c)
 98	if err != nil {
 99		t.Fatalf("expected nil, got error: %v", err)
100	}
101	if out != "ERR "+git.ErrInvalidRepo.Error() {
102		t.Fatalf("expected %q error, got %q", git.ErrInvalidRepo, out)
103	}
104}
105
106func readPktline(c net.Conn) (string, error) {
107	buf, err := io.ReadAll(c)
108	if err != nil {
109		return "", err
110	}
111	pktout := pktline.NewScanner(bytes.NewReader(buf))
112	if !pktout.Scan() {
113		return "", pktout.Err()
114	}
115	return strings.TrimSpace(string(pktout.Bytes())), nil
116}