script_test.go

  1package testscript
  2
  3import (
  4	"context"
  5	"flag"
  6	"fmt"
  7	"net"
  8	"os"
  9	"os/exec"
 10	"path/filepath"
 11	"runtime"
 12	"strings"
 13	"sync"
 14	"testing"
 15	"time"
 16
 17	"github.com/charmbracelet/soft-serve/server"
 18	"github.com/charmbracelet/soft-serve/server/config"
 19	"github.com/charmbracelet/soft-serve/server/test"
 20	"github.com/rogpeppe/go-internal/testscript"
 21)
 22
 23var update = flag.Bool("update", false, "update script files")
 24
 25func TestScript(t *testing.T) {
 26	flag.Parse()
 27	var lock sync.Mutex
 28
 29	t.Setenv("SOFT_SERVE_TEST_NO_HOOKS", "1")
 30
 31	// we'll use this key to talk with soft serve, and since testscript changes
 32	// the cwd, we need to get its full path here
 33	key, err := filepath.Abs("./testdata/admin1")
 34	if err != nil {
 35		t.Fatal(err)
 36	}
 37
 38	// git does not handle 0600, and on clone, will save the files with its
 39	// default perm, 0644, which is too open for ssh.
 40	for _, f := range []string{
 41		"admin1",
 42		"admin2",
 43		"user1",
 44		"user2",
 45	} {
 46		if err := os.Chmod(filepath.Join("./testdata/", f), 0o600); err != nil {
 47			t.Fatal(err)
 48		}
 49	}
 50
 51	sshArgs := []string{
 52		"-F", "/dev/null",
 53		"-o", "StrictHostKeyChecking=no",
 54		"-o", "UserKnownHostsFile=/dev/null",
 55		"-o", "IdentityAgent=none",
 56		"-o", "IdentitiesOnly=yes",
 57		"-o", "ServerAliveInterval=60",
 58		"-i", key,
 59	}
 60
 61	check := func(ts *testscript.TestScript, err error, neg bool) {
 62		if neg && err == nil {
 63			ts.Fatalf("expected error, got nil")
 64		}
 65		if !neg {
 66			ts.Check(err)
 67		}
 68	}
 69
 70	testscript.Run(t, testscript.Params{
 71		Dir:           "testdata/script",
 72		UpdateScripts: *update,
 73		Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
 74			"soft": func(ts *testscript.TestScript, neg bool, args []string) {
 75				args = append(
 76					sshArgs,
 77					append([]string{
 78						"-p", ts.Getenv("SSH_PORT"),
 79						"localhost",
 80						"--",
 81					}, args...)...,
 82				)
 83				if runtime.GOOS == "windows" {
 84					cmd := exec.Command("ssh", args...)
 85					out, err := cmd.CombinedOutput()
 86					ts.Logf("RUNNING %v: output: %s error: %v", cmd.Args, string(out), err)
 87				}
 88				check(ts, ts.Exec("ssh", args...), neg)
 89			},
 90			"git": func(ts *testscript.TestScript, neg bool, args []string) {
 91				ts.Setenv(
 92					"GIT_SSH_COMMAND",
 93					strings.Join(append([]string{"ssh"}, sshArgs...), " "),
 94				)
 95				args = append([]string{
 96					"-c", "user.email=john@example.com",
 97					"-c", "user.name=John Doe",
 98				}, args...)
 99				check(ts, ts.Exec("git", args...), neg)
100			},
101			"mkreadme": func(ts *testscript.TestScript, neg bool, args []string) {
102				if len(args) != 1 {
103					ts.Fatalf("must have exactly 1 arg, the filename, got %d", len(args))
104				}
105				check(ts, os.WriteFile(ts.MkAbs(args[0]), []byte("# example\ntest project"), 0o644), neg)
106			},
107		},
108		Setup: func(e *testscript.Env) error {
109			sshPort := test.RandomPort()
110			e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
111			data := t.TempDir()
112			cfg := config.Config{
113				Name:     "Test Soft Serve",
114				DataPath: data,
115				InitialAdminKeys: []string{
116					"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJI/1tawpdPmzuJcTGTJ+QReqB6cRUdKj4iQIdJUFdrl",
117				},
118				SSH: config.SSHConfig{
119					ListenAddr:    fmt.Sprintf("localhost:%d", sshPort),
120					PublicURL:     fmt.Sprintf("ssh://localhost:%d", sshPort),
121					KeyPath:       filepath.Join(data, "ssh", "soft_serve_host_ed25519"),
122					ClientKeyPath: filepath.Join(data, "ssh", "soft_serve_client_ed25519"),
123				},
124				Git: config.GitConfig{
125					ListenAddr:     fmt.Sprintf("localhost:%d", test.RandomPort()),
126					IdleTimeout:    3,
127					MaxConnections: 32,
128				},
129				HTTP: config.HTTPConfig{
130					ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
131					PublicURL:  fmt.Sprintf("http://localhost:%d", test.RandomPort()),
132				},
133				Stats: config.StatsConfig{
134					ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
135				},
136				Log: config.LogConfig{
137					Format:     "text",
138					TimeFormat: time.DateTime,
139				},
140			}
141			ctx := config.WithContext(context.Background(), &cfg)
142
143			// prevent race condition in lipgloss...
144			// this will probably be autofixed when we start using the colors
145			// from the ssh session instead of the server.
146			// XXX: take another look at this soon
147			lock.Lock()
148			srv, err := server.NewServer(ctx)
149			if err != nil {
150				return err
151			}
152			lock.Unlock()
153
154			go func() {
155				if err := srv.Start(); err != nil {
156					e.T().Fatal(err)
157				}
158			}()
159
160			e.Defer(func() {
161				ctx, cancel := context.WithTimeout(context.Background(), time.Second)
162				defer cancel()
163				if err := srv.Shutdown(ctx); err != nil {
164					e.T().Fatal(err)
165				}
166			})
167
168			// wait until the server is up
169			for {
170				conn, _ := net.DialTimeout(
171					"tcp",
172					net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
173					time.Second,
174				)
175				if conn != nil {
176					conn.Close()
177					break
178				}
179			}
180
181			return nil
182		},
183	})
184}