script_test.go

  1package testscript
  2
  3import (
  4	"context"
  5	"flag"
  6	"fmt"
  7	"net"
  8	"os"
  9	"path/filepath"
 10	"strings"
 11	"sync"
 12	"testing"
 13	"time"
 14
 15	"github.com/charmbracelet/soft-serve/server"
 16	"github.com/charmbracelet/soft-serve/server/config"
 17	"github.com/charmbracelet/soft-serve/server/test"
 18	"github.com/rogpeppe/go-internal/testscript"
 19)
 20
 21var update = flag.Bool("update", false, "update script files")
 22
 23func TestScript(t *testing.T) {
 24	flag.Parse()
 25	var lock sync.Mutex
 26
 27	t.Setenv("SOFT_SERVE_TEST_NO_HOOKS", "1")
 28
 29	// we'll use this key to talk with soft serve, and since testscript changes
 30	// the cwd, we need to get its full path here
 31	key, err := filepath.Abs("./testdata/admin1")
 32	if err != nil {
 33		t.Fatal(err)
 34	}
 35
 36	// git does not handle 0600, and on clone, will save the files with its
 37	// default perm, 0644, which is too open for ssh.
 38	for _, f := range []string{
 39		"admin1",
 40		"admin2",
 41		"user1",
 42		"user2",
 43	} {
 44		if err := os.Chmod(filepath.Join("./testdata/", f), 0o600); err != nil {
 45			t.Fatal(err)
 46		}
 47	}
 48
 49	sshArgs := []string{
 50		"-F", "/dev/null",
 51		"-o", "StrictHostKeyChecking=no",
 52		"-o", "UserKnownHostsFile=/dev/null",
 53		"-o", "IdentityAgent=none",
 54		"-o", "IdentitiesOnly=yes",
 55		"-i", key,
 56	}
 57
 58	check := func(ts *testscript.TestScript, err error, neg bool) {
 59		if neg && err == nil {
 60			ts.Fatalf("expected error, got nil")
 61		}
 62		if !neg {
 63			ts.Check(err)
 64		}
 65	}
 66
 67	testscript.Run(t, testscript.Params{
 68		Dir:           "testdata/script",
 69		UpdateScripts: *update,
 70		Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
 71			"soft": func(ts *testscript.TestScript, neg bool, args []string) {
 72				args = append(
 73					sshArgs,
 74					append([]string{
 75						"-p", ts.Getenv("SSH_PORT"),
 76						"localhost",
 77						"--",
 78					}, args...)...,
 79				)
 80				check(ts, ts.Exec("ssh", args...), neg)
 81			},
 82			"git": func(ts *testscript.TestScript, neg bool, args []string) {
 83				ts.Setenv(
 84					"GIT_SSH_COMMAND",
 85					strings.Join(append([]string{"ssh"}, sshArgs...), " "),
 86				)
 87				args = append([]string{
 88					"-c", "user.email=john@example.com",
 89					"-c", "user.name=John Doe",
 90				}, args...)
 91				check(ts, ts.Exec("git", args...), neg)
 92			},
 93			"mkreadme": func(ts *testscript.TestScript, neg bool, args []string) {
 94				if len(args) != 1 {
 95					ts.Fatalf("must have exactly 1 arg, the filename, got %d", len(args))
 96				}
 97				check(ts, os.WriteFile(ts.MkAbs(args[0]), []byte("# example\ntest project"), 0o644), neg)
 98			},
 99		},
100		Setup: func(e *testscript.Env) error {
101			sshPort := test.RandomPort()
102			e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
103			data := t.TempDir()
104			cfg := config.Config{
105				Name:     "Test Soft Serve",
106				DataPath: data,
107				InitialAdminKeys: []string{
108					"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJI/1tawpdPmzuJcTGTJ+QReqB6cRUdKj4iQIdJUFdrl",
109				},
110				SSH: config.SSHConfig{
111					ListenAddr:    fmt.Sprintf("localhost:%d", sshPort),
112					PublicURL:     fmt.Sprintf("ssh://localhost:%d", sshPort),
113					KeyPath:       filepath.Join(data, "ssh", "soft_serve_host_ed25519"),
114					ClientKeyPath: filepath.Join(data, "ssh", "soft_serve_client_ed25519"),
115				},
116				Git: config.GitConfig{
117					ListenAddr:     fmt.Sprintf("localhost:%d", test.RandomPort()),
118					IdleTimeout:    3,
119					MaxConnections: 32,
120				},
121				HTTP: config.HTTPConfig{
122					ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
123					PublicURL:  fmt.Sprintf("http://localhost:%d", test.RandomPort()),
124				},
125				Stats: config.StatsConfig{
126					ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
127				},
128				Log: config.LogConfig{
129					Format:     "text",
130					TimeFormat: time.DateTime,
131				},
132			}
133			ctx := config.WithContext(context.Background(), &cfg)
134
135			// prevent race condition in lipgloss...
136			// this will probably be autofixed when we start using the colors
137			// from the ssh session instead of the server.
138			// XXX: take another look at this soon
139			lock.Lock()
140			srv, err := server.NewServer(ctx)
141			if err != nil {
142				return err
143			}
144			lock.Unlock()
145
146			go func() {
147				if err := srv.Start(); err != nil {
148					e.T().Fatal(err)
149				}
150			}()
151
152			e.Defer(func() {
153				ctx, cancel := context.WithTimeout(context.Background(), time.Second)
154				defer cancel()
155				if err := srv.Shutdown(ctx); err != nil {
156					e.T().Fatal(err)
157				}
158			})
159
160			// wait until the server is up
161			for {
162				conn, _ := net.DialTimeout(
163					"tcp",
164					net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
165					time.Second,
166				)
167				if conn != nil {
168					conn.Close()
169					break
170				}
171			}
172
173			return nil
174		},
175	})
176}