1package server
 2
 3import (
 4	"bytes"
 5	"errors"
 6	"fmt"
 7	"io"
 8	"log"
 9	"net"
10	"os"
11	"strings"
12	"testing"
13	"time"
14
15	"github.com/charmbracelet/soft-serve/server/config"
16	"github.com/go-git/go-git/v5/plumbing/format/pktline"
17)
18
19var testDaemon *GitDaemon
20
21func TestMain(m *testing.M) {
22	tmp, err := os.MkdirTemp("", "soft-serve-test")
23	if err != nil {
24		log.Fatal(err)
25	}
26	defer os.RemoveAll(tmp)
27	os.Setenv("SOFT_SERVE_DATA_PATH", tmp)
28	os.Setenv("SOFT_SERVE_GIT_MAX_CONNECTIONS", "3")
29	os.Setenv("SOFT_SERVE_GIT_MAX_TIMEOUT", "100")
30	os.Setenv("SOFT_SERVE_GIT_IDLE_TIMEOUT", "1")
31	os.Setenv("SOFT_SERVE_GIT_LISTEN_ADDR", fmt.Sprintf(":%d", randomPort()))
32	cfg := config.DefaultConfig()
33	d, err := NewGitDaemon(cfg)
34	if err != nil {
35		log.Fatal(err)
36	}
37	testDaemon = d
38	go func() {
39		if err := d.Start(); err != ErrServerClosed {
40			log.Fatal(err)
41		}
42	}()
43	code := m.Run()
44	os.Unsetenv("SOFT_SERVE_DATA_PATH")
45	os.Unsetenv("SOFT_SERVE_GIT_MAX_CONNECTIONS")
46	os.Unsetenv("SOFT_SERVE_GIT_MAX_TIMEOUT")
47	os.Unsetenv("SOFT_SERVE_GIT_IDLE_TIMEOUT")
48	os.Unsetenv("SOFT_SERVE_GIT_LISTEN_ADDR")
49	_ = d.Close()
50	os.Exit(code)
51}
52
53func TestIdleTimeout(t *testing.T) {
54	c, err := net.Dial("tcp", testDaemon.addr)
55	if err != nil {
56		t.Fatal(err)
57	}
58	time.Sleep(2 * time.Second)
59	out, err := readPktline(c)
60	if err != nil && !errors.Is(err, io.EOF) {
61		t.Fatalf("expected nil, got error: %v", err)
62	}
63	if out != ErrTimeout.Error() {
64		t.Fatalf("expected %q error, got %q", ErrTimeout, out)
65	}
66}
67
68func TestInvalidRepo(t *testing.T) {
69	c, err := net.Dial("tcp", testDaemon.addr)
70	if err != nil {
71		t.Fatal(err)
72	}
73	if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
74		t.Fatalf("expected nil, got error: %v", err)
75	}
76	out, err := readPktline(c)
77	if err != nil {
78		t.Fatalf("expected nil, got error: %v", err)
79	}
80	if out != ErrInvalidRepo.Error() {
81		t.Fatalf("expected %q error, got %q", ErrInvalidRepo, out)
82	}
83}
84
85func readPktline(c net.Conn) (string, error) {
86	buf, err := io.ReadAll(c)
87	if err != nil {
88		return "", err
89	}
90	pktout := pktline.NewScanner(bytes.NewReader(buf))
91	if !pktout.Scan() {
92		return "", pktout.Err()
93	}
94	return strings.TrimSpace(string(pktout.Bytes())), nil
95}