daemon_test.go

  1package daemon
  2
  3import (
  4	"bytes"
  5	"context"
  6	"fmt"
  7	"io"
  8	"log"
  9	"net"
 10	"os"
 11	"strings"
 12	"testing"
 13
 14	"github.com/charmbracelet/soft-serve/pkg/backend"
 15	"github.com/charmbracelet/soft-serve/pkg/config"
 16	"github.com/charmbracelet/soft-serve/pkg/db"
 17	"github.com/charmbracelet/soft-serve/pkg/db/migrate"
 18	"github.com/charmbracelet/soft-serve/pkg/git"
 19	"github.com/charmbracelet/soft-serve/pkg/store"
 20	"github.com/charmbracelet/soft-serve/pkg/store/database"
 21	"github.com/charmbracelet/soft-serve/pkg/test"
 22	"github.com/go-git/go-git/v5/plumbing/format/pktline"
 23	_ "modernc.org/sqlite" // sqlite driver
 24)
 25
 26var testDaemon *GitDaemon
 27
 28func TestMain(m *testing.M) {
 29	tmp, err := os.MkdirTemp("", "soft-serve-test")
 30	if err != nil {
 31		log.Fatal(err)
 32	}
 33	defer os.RemoveAll(tmp)
 34	ctx := context.TODO()
 35	cfg := config.DefaultConfig()
 36	cfg.DataPath = tmp
 37	cfg.Git.MaxConnections = 3
 38	cfg.Git.MaxTimeout = 100
 39	cfg.Git.IdleTimeout = 1
 40	cfg.Git.ListenAddr = fmt.Sprintf(":%d", test.RandomPort())
 41	if err := cfg.Validate(); err != nil {
 42		log.Fatal(err)
 43	}
 44	ctx = config.WithContext(ctx, cfg)
 45	dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
 46	if err != nil {
 47		log.Fatal(err)
 48	}
 49	defer dbx.Close() // nolint: errcheck
 50	if err := migrate.Migrate(ctx, dbx); err != nil {
 51		log.Fatal(err)
 52	}
 53	datastore := database.New(ctx, dbx)
 54	ctx = store.WithContext(ctx, datastore)
 55	be := backend.New(ctx, cfg, dbx)
 56	ctx = backend.WithContext(ctx, be)
 57	d, err := NewGitDaemon(ctx)
 58	if err != nil {
 59		log.Fatal(err)
 60	}
 61	testDaemon = d
 62	go func() {
 63		if err := d.Start(); err != ErrServerClosed {
 64			log.Fatal(err)
 65		}
 66	}()
 67	code := m.Run()
 68	os.Unsetenv("SOFT_SERVE_DATA_PATH")
 69	os.Unsetenv("SOFT_SERVE_GIT_MAX_CONNECTIONS")
 70	os.Unsetenv("SOFT_SERVE_GIT_MAX_TIMEOUT")
 71	os.Unsetenv("SOFT_SERVE_GIT_IDLE_TIMEOUT")
 72	os.Unsetenv("SOFT_SERVE_GIT_LISTEN_ADDR")
 73	_ = d.Close()
 74	_ = dbx.Close()
 75	os.Exit(code)
 76}
 77
 78func TestIdleTimeout(t *testing.T) {
 79	c, err := net.Dial("tcp", testDaemon.addr)
 80	if err != nil {
 81		t.Fatal(err)
 82	}
 83	_, err = readPktline(c)
 84	if err != nil && err.Error() != git.ErrTimeout.Error() {
 85		t.Fatalf("expected %q error, got %q", git.ErrTimeout, err)
 86	}
 87}
 88
 89func TestInvalidRepo(t *testing.T) {
 90	c, err := net.Dial("tcp", testDaemon.addr)
 91	if err != nil {
 92		t.Fatal(err)
 93	}
 94	if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
 95		t.Fatalf("expected nil, got error: %v", err)
 96	}
 97	_, err = readPktline(c)
 98	if err != nil && err.Error() != git.ErrInvalidRepo.Error() {
 99		t.Fatalf("expected %q error, got %q", git.ErrInvalidRepo, err)
100	}
101}
102
103func readPktline(c net.Conn) (string, error) {
104	buf, err := io.ReadAll(c)
105	if err != nil {
106		return "", err
107	}
108	pktout := pktline.NewScanner(bytes.NewReader(buf))
109	if !pktout.Scan() {
110		return "", pktout.Err()
111	}
112	return strings.TrimSpace(string(pktout.Bytes())), nil
113}