script_test.go

  1package testscript
  2
  3import (
  4	"bytes"
  5	"context"
  6	"encoding/json"
  7	"flag"
  8	"fmt"
  9	"io"
 10	"math/rand"
 11	"net"
 12	"net/http"
 13	"net/url"
 14	"os"
 15	"os/exec"
 16	"path/filepath"
 17	"runtime"
 18	"strconv"
 19	"strings"
 20	"testing"
 21	"time"
 22
 23	"github.com/charmbracelet/keygen"
 24	"github.com/charmbracelet/soft-serve/pkg/config"
 25	"github.com/charmbracelet/soft-serve/pkg/db"
 26	"github.com/charmbracelet/soft-serve/pkg/test"
 27	"github.com/rogpeppe/go-internal/testscript"
 28	"github.com/spf13/cobra"
 29	"golang.org/x/crypto/ssh"
 30)
 31
 32var (
 33	update  = flag.Bool("update", false, "update script files")
 34	binPath string
 35)
 36
 37func PrepareBuildCommand(binPath string) *exec.Cmd {
 38	_, disableRaceSet := os.LookupEnv("SOFT_SERVE_DISABLE_RACE_CHECKS")
 39	if disableRaceSet {
 40		// don't add the -race flag
 41		return exec.Command("go", "build", "-cover", "-o", binPath, filepath.Join("..", "cmd", "soft"))
 42	}
 43	return exec.Command("go", "build", "-race", "-cover", "-o", binPath, filepath.Join("..", "cmd", "soft"))
 44}
 45
 46func TestMain(m *testing.M) {
 47	tmp, err := os.MkdirTemp("", "soft-serve*")
 48	if err != nil {
 49		fmt.Fprintf(os.Stderr, "failed to create temporary directory: %s", err)
 50		os.Exit(1)
 51	}
 52	defer os.RemoveAll(tmp)
 53
 54	binPath = filepath.Join(tmp, "soft")
 55	if runtime.GOOS == "windows" {
 56		binPath += ".exe"
 57	}
 58
 59	// Build the soft binary with -cover flag.
 60	cmd := PrepareBuildCommand(binPath)
 61	if err := cmd.Run(); err != nil {
 62		fmt.Fprintf(os.Stderr, "failed to build soft-serve binary: %s", err)
 63		os.Exit(1)
 64	}
 65
 66	// Run tests
 67	os.Exit(m.Run())
 68}
 69
 70func TestScript(t *testing.T) {
 71	flag.Parse()
 72
 73	mkkey := func(name string) (string, *keygen.SSHKeyPair) {
 74		path := filepath.Join(t.TempDir(), name)
 75		pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
 76		if err != nil {
 77			t.Fatal(err)
 78		}
 79		return path, pair
 80	}
 81
 82	admin1Key, admin1 := mkkey("admin1")
 83	_, admin2 := mkkey("admin2")
 84	user1Key, user1 := mkkey("user1")
 85
 86	testscript.Run(t, testscript.Params{
 87		Dir:                 "./testdata/",
 88		UpdateScripts:       *update,
 89		RequireExplicitExec: true,
 90		Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
 91			"soft":                   cmdSoft("admin", admin1.Signer()),
 92			"usoft":                  cmdSoft("user1", user1.Signer()),
 93			"git":                    cmdGit(admin1Key),
 94			"ugit":                   cmdGit(user1Key),
 95			"curl":                   cmdCurl,
 96			"mkfile":                 cmdMkfile,
 97			"envfile":                cmdEnvfile,
 98			"readfile":               cmdReadfile,
 99			"dos2unix":               cmdDos2Unix,
100			"new-webhook":            cmdNewWebhook,
101			"ensureserverrunning":    cmdEnsureServerRunning,
102			"ensureservernotrunning": cmdEnsureServerNotRunning,
103			"stopserver":             cmdStopserver,
104			"ui":                     cmdUI(admin1.Signer()),
105			"uui":                    cmdUI(user1.Signer()),
106		},
107		Setup: func(e *testscript.Env) error {
108			// Add binPath to PATH
109			e.Setenv("PATH", fmt.Sprintf("%s%c%s", filepath.Dir(binPath), os.PathListSeparator, e.Getenv("PATH")))
110
111			data := t.TempDir()
112			sshPort := test.RandomPort()
113			sshListen := fmt.Sprintf("localhost:%d", sshPort)
114			gitPort := test.RandomPort()
115			gitListen := fmt.Sprintf("localhost:%d", gitPort)
116			httpPort := test.RandomPort()
117			httpListen := fmt.Sprintf("localhost:%d", httpPort)
118			statsPort := test.RandomPort()
119			statsListen := fmt.Sprintf("localhost:%d", statsPort)
120			serverName := "Test Soft Serve"
121
122			e.Setenv("DATA_PATH", data)
123			e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
124			e.Setenv("HTTP_PORT", fmt.Sprintf("%d", httpPort))
125			e.Setenv("STATS_PORT", fmt.Sprintf("%d", statsPort))
126			e.Setenv("GIT_PORT", fmt.Sprintf("%d", gitPort))
127			e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
128			e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
129			e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
130			e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
131			e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
132
133			// This is used to set up test specific configuration and http endpoints
134			e.Setenv("SOFT_SERVE_TESTRUN", "1")
135
136			// This will disable the default lipgloss renderer colors
137			e.Setenv("SOFT_SERVE_NO_COLOR", "1")
138
139			// Soft Serve debug environment variables
140			for _, env := range []string{
141				"SOFT_SERVE_DEBUG",
142				"SOFT_SERVE_VERBOSE",
143			} {
144				if v, ok := os.LookupEnv(env); ok {
145					e.Setenv(env, v)
146				}
147			}
148
149			// TODO: test different configs
150			cfg := config.DefaultConfig()
151			cfg.DataPath = data
152			cfg.Name = serverName
153			cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
154			cfg.SSH.ListenAddr = sshListen
155			cfg.SSH.PublicURL = "ssh://" + sshListen
156			cfg.Git.ListenAddr = gitListen
157			cfg.HTTP.ListenAddr = httpListen
158			cfg.HTTP.PublicURL = "http://" + httpListen
159			cfg.Stats.ListenAddr = statsListen
160			cfg.LFS.Enabled = true
161
162			// Parse os SOFT_SERVE environment variables
163			if err := cfg.ParseEnv(); err != nil {
164				return err
165			}
166
167			// Override the database data source if we're using postgres
168			// so we can create a temporary database for the tests.
169			if cfg.DB.Driver == "postgres" {
170				err, cleanup := setupPostgres(e.T(), cfg)
171				if err != nil {
172					return err
173				}
174				if cleanup != nil {
175					e.Defer(cleanup)
176				}
177			}
178
179			for _, env := range cfg.Environ() {
180				parts := strings.SplitN(env, "=", 2)
181				if len(parts) != 2 {
182					e.T().Fatal("invalid environment variable", env)
183				}
184				e.Setenv(parts[0], parts[1])
185			}
186
187			return nil
188		},
189	})
190}
191
192func cmdSoft(user string, key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
193	return func(ts *testscript.TestScript, neg bool, args []string) {
194		cli, err := ssh.Dial(
195			"tcp",
196			net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
197			&ssh.ClientConfig{
198				User:            user,
199				Auth:            []ssh.AuthMethod{ssh.PublicKeys(key)},
200				HostKeyCallback: ssh.InsecureIgnoreHostKey(),
201			},
202		)
203		ts.Check(err)
204		defer cli.Close()
205
206		sess, err := cli.NewSession()
207		ts.Check(err)
208		defer sess.Close()
209
210		sess.Stdout = ts.Stdout()
211		sess.Stderr = ts.Stderr()
212
213		check(ts, sess.Run(strings.Join(args, " ")), neg)
214	}
215}
216
217func cmdUI(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
218	return func(ts *testscript.TestScript, neg bool, args []string) {
219		if len(args) < 1 {
220			ts.Fatalf("usage: ui <quoted string input>")
221			return
222		}
223
224		cli, err := ssh.Dial(
225			"tcp",
226			net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
227			&ssh.ClientConfig{
228				User:            "git",
229				Auth:            []ssh.AuthMethod{ssh.PublicKeys(key)},
230				HostKeyCallback: ssh.InsecureIgnoreHostKey(),
231			},
232		)
233		check(ts, err, neg)
234		defer cli.Close()
235
236		sess, err := cli.NewSession()
237		check(ts, err, neg)
238		defer sess.Close()
239
240		// XXX: this is a hack to make the UI tests work
241		// cmp command always complains about an extra newline
242		// in the output
243		defer ts.Stdout().Write([]byte("\n"))
244
245		sess.Stdout = ts.Stdout()
246		sess.Stderr = ts.Stderr()
247
248		stdin, err := sess.StdinPipe()
249		check(ts, err, neg)
250
251		err = sess.RequestPty("dumb", 40, 80, ssh.TerminalModes{})
252		check(ts, err, neg)
253		check(ts, sess.Start(""), neg)
254
255		in, err := strconv.Unquote(args[0])
256		check(ts, err, neg)
257		reader := strings.NewReader(in)
258		go func() {
259			defer stdin.Close()
260			for {
261				r, _, err := reader.ReadRune()
262				if err == io.EOF {
263					break
264				}
265				check(ts, err, neg)
266				stdin.Write([]byte(string(r))) //nolint: errcheck
267
268				// Wait for the UI to process the input
269				time.Sleep(100 * time.Millisecond)
270			}
271		}()
272
273		check(ts, sess.Wait(), neg)
274	}
275}
276
277func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
278	if neg {
279		ts.Fatalf("unsupported: ! dos2unix")
280	}
281	if len(args) < 1 {
282		ts.Fatalf("usage: dos2unix paths...")
283	}
284	for _, arg := range args {
285		filename := ts.MkAbs(arg)
286		data, err := os.ReadFile(filename)
287		if err != nil {
288			ts.Fatalf("%s: %v", filename, err)
289		}
290
291		// Replace all '\r\n' with '\n'.
292		data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
293
294		if err := os.WriteFile(filename, data, 0o644); err != nil {
295			ts.Fatalf("%s: %v", filename, err)
296		}
297	}
298}
299
// sshConfig is the format string for the ssh client configuration written by
// cmdGit. Its single %q verb receives the path of the known hosts file.
var sshConfig = "\n" +
	"Host *\n" +
	"  UserKnownHostsFile %q\n" +
	"  StrictHostKeyChecking no\n" +
	"  IdentityAgent none\n" +
	"  IdentitiesOnly yes\n" +
	"  ServerAliveInterval 60\n"
308
309func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
310	return func(ts *testscript.TestScript, neg bool, args []string) {
311		ts.Check(os.WriteFile(
312			ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
313			[]byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
314			0o600,
315		))
316		sshArgs := []string{
317			"-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
318			"-i", filepath.ToSlash(key),
319		}
320		ts.Setenv(
321			"GIT_SSH_COMMAND",
322			strings.Join(append([]string{"ssh"}, sshArgs...), " "),
323		)
324		// Disable git prompting for credentials.
325		ts.Setenv("GIT_TERMINAL_PROMPT", "0")
326		args = append([]string{
327			"-c", "user.email=john@example.com",
328			"-c", "user.name=John Doe",
329		}, args...)
330		check(ts, ts.Exec("git", args...), neg)
331	}
332}
333
334func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
335	if len(args) < 2 {
336		ts.Fatalf("usage: mkfile path content")
337	}
338	check(ts, os.WriteFile(
339		ts.MkAbs(args[0]),
340		[]byte(strings.Join(args[1:], " ")),
341		0o644,
342	), neg)
343}
344
345func check(ts *testscript.TestScript, err error, neg bool) {
346	if neg && err == nil {
347		ts.Fatalf("expected error, got nil")
348	}
349	if !neg {
350		ts.Check(err)
351	}
352}
353
354func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) { //nolint:unparam
355	ts.Stdout().Write([]byte(ts.ReadFile(args[0])))
356}
357
358func cmdEnvfile(ts *testscript.TestScript, neg bool, args []string) { //nolint:unparam
359	if len(args) < 1 {
360		ts.Fatalf("usage: envfile key=file...")
361	}
362
363	for _, arg := range args {
364		parts := strings.SplitN(arg, "=", 2)
365		if len(parts) != 2 {
366			ts.Fatalf("usage: envfile key=file...")
367		}
368		key := parts[0]
369		file := parts[1]
370		ts.Setenv(key, strings.TrimSpace(ts.ReadFile(file)))
371	}
372}
373
374func cmdNewWebhook(ts *testscript.TestScript, neg bool, args []string) {
375	type webhookSite struct {
376		UUID string `json:"uuid"`
377	}
378
379	if len(args) != 1 {
380		ts.Fatalf("usage: new-webhook <env-name>")
381	}
382
383	const whSite = "https://webhook.site"
384	req, err := http.NewRequest(http.MethodPost, whSite+"/token", nil)
385	check(ts, err, neg)
386
387	resp, err := http.DefaultClient.Do(req)
388	check(ts, err, neg)
389
390	defer resp.Body.Close()
391	var site webhookSite
392	check(ts, json.NewDecoder(resp.Body).Decode(&site), neg)
393
394	ts.Setenv(args[0], whSite+"/"+site.UUID)
395}
396
397func cmdCurl(ts *testscript.TestScript, neg bool, args []string) {
398	var verbose bool
399	var headers []string
400	var data string
401	method := http.MethodGet
402
403	cmd := &cobra.Command{
404		Use:  "curl",
405		Args: cobra.MinimumNArgs(1),
406		RunE: func(cmd *cobra.Command, args []string) error {
407			url, err := url.Parse(args[0])
408			if err != nil {
409				return err
410			}
411
412			req, err := http.NewRequest(method, url.String(), nil)
413			if err != nil {
414				return err
415			}
416
417			if data != "" {
418				req.Body = io.NopCloser(strings.NewReader(data))
419			}
420
421			if verbose {
422				fmt.Fprintf(cmd.ErrOrStderr(), "< %s %s\n", req.Method, url.String())
423			}
424
425			for _, header := range headers {
426				parts := strings.SplitN(header, ":", 2)
427				if len(parts) != 2 {
428					return fmt.Errorf("invalid header: %s", header)
429				}
430				req.Header.Add(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
431			}
432
433			if userInfo := url.User; userInfo != nil {
434				password, _ := userInfo.Password()
435				req.SetBasicAuth(userInfo.Username(), password)
436			}
437
438			if verbose {
439				for key, values := range req.Header {
440					for _, value := range values {
441						fmt.Fprintf(cmd.ErrOrStderr(), "< %s: %s\n", key, value)
442					}
443				}
444			}
445
446			resp, err := http.DefaultClient.Do(req)
447			if err != nil {
448				return err
449			}
450
451			if verbose {
452				fmt.Fprintf(ts.Stderr(), "> %s\n", resp.Status)
453				for key, values := range resp.Header {
454					for _, value := range values {
455						fmt.Fprintf(cmd.ErrOrStderr(), "> %s: %s\n", key, value)
456					}
457				}
458			}
459
460			defer resp.Body.Close()
461			buf, err := io.ReadAll(resp.Body)
462			if err != nil {
463				return err
464			}
465
466			cmd.Print(string(buf))
467
468			return nil
469		},
470	}
471
472	cmd.SetArgs(args)
473	cmd.SetOut(ts.Stdout())
474	cmd.SetErr(ts.Stderr())
475
476	cmd.Flags().BoolVarP(&verbose, "verbose", "v", verbose, "verbose")
477	cmd.Flags().StringArrayVarP(&headers, "header", "H", nil, "HTTP header")
478	cmd.Flags().StringVarP(&method, "request", "X", method, "HTTP method")
479	cmd.Flags().StringVarP(&data, "data", "d", data, "HTTP data")
480
481	check(ts, cmd.Execute(), neg)
482}
483
484func cmdEnsureServerRunning(ts *testscript.TestScript, neg bool, args []string) { //nolint:unparam
485	if len(args) < 1 {
486		ts.Fatalf("Must supply a TCP port of one of the services to connect to. " +
487			"These are set as env vars as they are randomized. " +
488			"Example usage: \"cmdensureserverrunning SSH_PORT\"\n" +
489			"Valid values for the env var: SSH_PORT|HTTP_PORT|GIT_PORT|STATS_PORT")
490	}
491
492	port := ts.Getenv(args[0])
493
494	// verify that the server is up
495	addr := net.JoinHostPort("localhost", port)
496	for {
497		conn, _ := net.DialTimeout(
498			"tcp",
499			addr,
500			time.Second,
501		)
502		if conn != nil {
503			ts.Logf("Server is running on port: %s", port)
504			conn.Close()
505			break
506		}
507	}
508}
509
510func cmdEnsureServerNotRunning(ts *testscript.TestScript, neg bool, args []string) { //nolint:unparam
511	if len(args) < 1 {
512		ts.Fatalf("Must supply a TCP port of one of the services to connect to. " +
513			"These are set as env vars as they are randomized. " +
514			"Example usage: \"cmdensureservernotrunning SSH_PORT\"\n" +
515			"Valid values for the env var: SSH_PORT|HTTP_PORT|GIT_PORT|STATS_PORT")
516	}
517
518	port := ts.Getenv(args[0])
519
520	// verify that the server is not up
521	addr := net.JoinHostPort("localhost", port)
522	for {
523		conn, _ := net.DialTimeout(
524			"tcp",
525			addr,
526			time.Second,
527		)
528		if conn != nil {
529			ts.Fatalf("server is running on port %s while it should not be running", port)
530			conn.Close()
531		}
532		break
533	}
534}
535
536func cmdStopserver(ts *testscript.TestScript, neg bool, args []string) { //nolint:unparam
537	// stop the server
538	resp, err := http.DefaultClient.Head(fmt.Sprintf("%s/__stop", ts.Getenv("SOFT_SERVE_HTTP_PUBLIC_URL")))
539	check(ts, err, neg)
540	resp.Body.Close()
541	time.Sleep(time.Second * 2) // Allow some time for the server to stop
542}
543
544func setupPostgres(t testscript.T, cfg *config.Config) (error, func()) {
545	// Indicates postgres
546	// Create a disposable database
547	rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
548	dbName := fmt.Sprintf("softserve_test_%d", rnd.Int63())
549	dbDsn := cfg.DB.DataSource
550	if dbDsn == "" {
551		cfg.DB.DataSource = "postgres://postgres@localhost:5432/postgres?sslmode=disable"
552	}
553
554	dbUrl, err := url.Parse(cfg.DB.DataSource)
555	if err != nil {
556		return err, nil
557	}
558
559	scheme := dbUrl.Scheme
560	if scheme == "" {
561		scheme = "postgres"
562	}
563
564	host := dbUrl.Hostname()
565	if host == "" {
566		host = "localhost"
567	}
568
569	connInfo := fmt.Sprintf("host=%s sslmode=disable", host)
570	username := dbUrl.User.Username()
571	if username != "" {
572		connInfo += fmt.Sprintf(" user=%s", username)
573		password, ok := dbUrl.User.Password()
574		if ok {
575			username = fmt.Sprintf("%s:%s", username, password)
576			connInfo += fmt.Sprintf(" password=%s", password)
577		}
578		username = fmt.Sprintf("%s@", username)
579	} else {
580		connInfo += " user=postgres"
581		username = "postgres@"
582	}
583
584	port := dbUrl.Port()
585	if port != "" {
586		connInfo += fmt.Sprintf(" port=%s", port)
587		port = fmt.Sprintf(":%s", port)
588	}
589
590	cfg.DB.DataSource = fmt.Sprintf("%s://%s%s%s/%s?sslmode=disable",
591		scheme,
592		username,
593		host,
594		port,
595		dbName,
596	)
597
598	// Create the database
599	dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
600	if err != nil {
601		return err, nil
602	}
603
604	if _, err := dbx.Exec("CREATE DATABASE " + dbName); err != nil {
605		return err, nil
606	}
607
608	return nil, func() {
609		dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
610		if err != nil {
611			t.Fatal("failed to open database", dbName, err)
612		}
613
614		if _, err := dbx.Exec("DROP DATABASE " + dbName); err != nil {
615			t.Fatal("failed to drop database", dbName, err)
616		}
617	}
618}