script_test.go

  1package testscript
  2
  3import (
  4	"bytes"
  5	"context"
  6	"encoding/json"
  7	"flag"
  8	"fmt"
  9	"io"
 10	"math/rand"
 11	"net"
 12	"net/http"
 13	"net/url"
 14	"os"
 15	"os/exec"
 16	"path/filepath"
 17	"runtime"
 18	"strconv"
 19	"strings"
 20	"testing"
 21	"time"
 22
 23	"github.com/charmbracelet/keygen"
 24	"github.com/charmbracelet/soft-serve/pkg/config"
 25	"github.com/charmbracelet/soft-serve/pkg/db"
 26	"github.com/charmbracelet/soft-serve/pkg/test"
 27	"github.com/rogpeppe/go-internal/testscript"
 28	"github.com/spf13/cobra"
 29	"golang.org/x/crypto/ssh"
 30)
 31
 32var (
 33	update  = flag.Bool("update", false, "update script files")
 34	binPath string
 35)
 36
 37func PrepareBuildCommand(binPath string) *exec.Cmd {
 38	_, disableRaceSet := os.LookupEnv("SOFT_SERVE_DISABLE_RACE_CHECKS")
 39	if disableRaceSet {
 40		// don't add the -race flag
 41		return exec.Command("go", "build", "-cover", "-o", binPath, filepath.Join("..", "cmd", "soft")) //nolint:noctx
 42	}
 43	return exec.Command("go", "build", "-race", "-cover", "-o", binPath, filepath.Join("..", "cmd", "soft")) //nolint:noctx
 44}
 45
 46func TestMain(m *testing.M) {
 47	tmp, err := os.MkdirTemp("", "soft-serve*")
 48	if err != nil {
 49		fmt.Fprintf(os.Stderr, "failed to create temporary directory: %s", err)
 50		os.Exit(1)
 51	}
 52	defer os.RemoveAll(tmp)
 53
 54	binPath = filepath.Join(tmp, "soft")
 55	if runtime.GOOS == "windows" {
 56		binPath += ".exe"
 57	}
 58
 59	// Build the soft binary with -cover flag.
 60	cmd := PrepareBuildCommand(binPath)
 61	if err := cmd.Run(); err != nil {
 62		fmt.Fprintf(os.Stderr, "failed to build soft-serve binary: %s", err)
 63		os.Exit(1)
 64	}
 65
 66	// Run tests
 67	os.Exit(m.Run())
 68}
 69
 70func TestScript(t *testing.T) {
 71	flag.Parse()
 72
 73	mkkey := func(name string) (string, *keygen.SSHKeyPair) {
 74		path := filepath.Join(t.TempDir(), name)
 75		pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
 76		if err != nil {
 77			t.Fatal(err)
 78		}
 79		return path, pair
 80	}
 81
 82	admin1Key, admin1 := mkkey("admin1")
 83	_, admin2 := mkkey("admin2")
 84	user1Key, user1 := mkkey("user1")
 85	attackerKey, attacker := mkkey("attacker")
 86	attackerSigner := &maliciousSigner{
 87		publicKey: admin1.PublicKey(),
 88	}
 89
 90	testscript.Run(t, testscript.Params{
 91		Dir:                 "./testdata/",
 92		UpdateScripts:       *update,
 93		RequireExplicitExec: true,
 94		Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
 95			"soft":                   cmdSoft("admin", admin1.Signer()),
 96			"usoft":                  cmdSoft("user1", user1.Signer()),
 97			"attacksoft":             cmdSoft("attacker", attackerSigner, attacker.Signer()),
 98			"git":                    cmdGit(admin1Key),
 99			"ugit":                   cmdGit(user1Key),
100			"agit":                   cmdGit(attackerKey),
101			"curl":                   cmdCurl,
102			"mkfile":                 cmdMkfile,
103			"envfile":                cmdEnvfile,
104			"readfile":               cmdReadfile,
105			"dos2unix":               cmdDos2Unix,
106			"new-webhook":            cmdNewWebhook,
107			"ensureserverrunning":    cmdEnsureServerRunning,
108			"ensureservernotrunning": cmdEnsureServerNotRunning,
109			"stopserver":             cmdStopserver,
110			"ui":                     cmdUI(admin1.Signer()),
111			"uui":                    cmdUI(user1.Signer()),
112		},
113		Setup: func(e *testscript.Env) error {
114			// Add binPath to PATH
115			e.Setenv("PATH", fmt.Sprintf("%s%c%s", filepath.Dir(binPath), os.PathListSeparator, e.Getenv("PATH")))
116
117			data := t.TempDir()
118			sshPort := test.RandomPort()
119			sshListen := fmt.Sprintf("localhost:%d", sshPort)
120			gitPort := test.RandomPort()
121			gitListen := fmt.Sprintf("localhost:%d", gitPort)
122			httpPort := test.RandomPort()
123			httpListen := fmt.Sprintf("localhost:%d", httpPort)
124			statsPort := test.RandomPort()
125			statsListen := fmt.Sprintf("localhost:%d", statsPort)
126			serverName := "Test Soft Serve"
127
128			e.Setenv("DATA_PATH", data)
129			e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
130			e.Setenv("HTTP_PORT", fmt.Sprintf("%d", httpPort))
131			e.Setenv("STATS_PORT", fmt.Sprintf("%d", statsPort))
132			e.Setenv("GIT_PORT", fmt.Sprintf("%d", gitPort))
133			e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
134			e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
135			e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
136			e.Setenv("ATTACKER_AUTHORIZED_KEY", attacker.AuthorizedKey())
137			e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
138			e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
139
140			// This is used to set up test specific configuration and http endpoints
141			e.Setenv("SOFT_SERVE_TESTRUN", "1")
142
143			// This will disable the default lipgloss renderer colors
144			e.Setenv("SOFT_SERVE_NO_COLOR", "1")
145
146			// Soft Serve debug environment variables
147			for _, env := range []string{
148				"SOFT_SERVE_DEBUG",
149				"SOFT_SERVE_VERBOSE",
150			} {
151				if v, ok := os.LookupEnv(env); ok {
152					e.Setenv(env, v)
153				}
154			}
155
156			// TODO: test different configs
157			cfg := config.DefaultConfig()
158			cfg.DataPath = data
159			cfg.Name = serverName
160			cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
161			cfg.SSH.ListenAddr = sshListen
162			cfg.SSH.PublicURL = "ssh://" + sshListen
163			cfg.Git.ListenAddr = gitListen
164			cfg.HTTP.ListenAddr = httpListen
165			cfg.HTTP.PublicURL = "http://" + httpListen
166			cfg.Stats.ListenAddr = statsListen
167			cfg.LFS.Enabled = true
168
169			// Parse os SOFT_SERVE environment variables
170			if err := cfg.ParseEnv(); err != nil {
171				return err
172			}
173
174			// Override the database data source if we're using postgres
175			// so we can create a temporary database for the tests.
176			if cfg.DB.Driver == "postgres" {
177				cleanup, err := setupPostgres(e.T(), cfg)
178				if err != nil {
179					return err
180				}
181				if cleanup != nil {
182					e.Defer(cleanup)
183				}
184			}
185
186			for _, env := range cfg.Environ() {
187				parts := strings.SplitN(env, "=", 2)
188				if len(parts) != 2 {
189					e.T().Fatal("invalid environment variable", env)
190				}
191				e.Setenv(parts[0], parts[1])
192			}
193
194			return nil
195		},
196	})
197}
198
199func cmdSoft(user string, keys ...ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
200	return func(ts *testscript.TestScript, neg bool, args []string) {
201		cli, err := ssh.Dial(
202			"tcp",
203			net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
204			&ssh.ClientConfig{
205				User:            user,
206				Auth:            []ssh.AuthMethod{ssh.PublicKeys(keys...)},
207				HostKeyCallback: ssh.InsecureIgnoreHostKey(),
208			},
209		)
210		ts.Check(err)
211		defer cli.Close()
212
213		sess, err := cli.NewSession()
214		ts.Check(err)
215		defer sess.Close()
216
217		sess.Stdout = ts.Stdout()
218		sess.Stderr = ts.Stderr()
219
220		check(ts, sess.Run(strings.Join(args, " ")), neg)
221	}
222}
223
224func cmdUI(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
225	return func(ts *testscript.TestScript, neg bool, args []string) {
226		if len(args) < 1 {
227			ts.Fatalf("usage: ui <quoted string input>")
228			return
229		}
230
231		cli, err := ssh.Dial(
232			"tcp",
233			net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
234			&ssh.ClientConfig{
235				User:            "git",
236				Auth:            []ssh.AuthMethod{ssh.PublicKeys(key)},
237				HostKeyCallback: ssh.InsecureIgnoreHostKey(),
238			},
239		)
240		check(ts, err, neg)
241		defer cli.Close()
242
243		sess, err := cli.NewSession()
244		check(ts, err, neg)
245		defer sess.Close()
246
247		// XXX: this is a hack to make the UI tests work
248		// cmp command always complains about an extra newline
249		// in the output
250		defer ts.Stdout().Write([]byte("\n"))
251
252		sess.Stdout = ts.Stdout()
253		sess.Stderr = ts.Stderr()
254
255		stdin, err := sess.StdinPipe()
256		check(ts, err, neg)
257
258		err = sess.RequestPty("dumb", 40, 80, ssh.TerminalModes{})
259		check(ts, err, neg)
260		check(ts, sess.Start(""), neg)
261
262		in, err := strconv.Unquote(args[0])
263		check(ts, err, neg)
264		reader := strings.NewReader(in)
265		go func() {
266			defer stdin.Close()
267			for {
268				r, _, err := reader.ReadRune()
269				if err == io.EOF {
270					break
271				}
272				check(ts, err, neg)
273				_, _ = io.WriteString(stdin, string(r))
274
275				// Wait for the UI to process the input
276				time.Sleep(100 * time.Millisecond)
277			}
278		}()
279
280		check(ts, sess.Wait(), neg)
281	}
282}
283
284func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
285	if neg {
286		ts.Fatalf("unsupported: ! dos2unix")
287	}
288	if len(args) < 1 {
289		ts.Fatalf("usage: dos2unix paths...")
290	}
291	for _, arg := range args {
292		filename := ts.MkAbs(arg)
293		data, err := os.ReadFile(filename)
294		if err != nil {
295			ts.Fatalf("%s: %v", filename, err)
296		}
297
298		// Replace all '\r\n' with '\n'.
299		data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
300
301		if err := os.WriteFile(filename, data, 0o644); err != nil {
302			ts.Fatalf("%s: %v", filename, err)
303		}
304	}
305}
306
// sshConfig is the OpenSSH client configuration template that cmdGit writes
// before every git invocation. Its single %q verb receives the path of the
// known_hosts file. Host key checking, the SSH agent and any identities other
// than the one passed with -i are disabled for all hosts.
var sshConfig = "\n" +
	"Host *\n" +
	"  UserKnownHostsFile %q\n" +
	"  StrictHostKeyChecking no\n" +
	"  IdentityAgent none\n" +
	"  IdentitiesOnly yes\n" +
	"  ServerAliveInterval 60\n"
315
316func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
317	return func(ts *testscript.TestScript, neg bool, args []string) {
318		ts.Check(os.WriteFile(
319			ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
320			[]byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
321			0o600,
322		))
323		sshArgs := []string{
324			"-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
325			"-i", filepath.ToSlash(key),
326		}
327		ts.Setenv(
328			"GIT_SSH_COMMAND",
329			strings.Join(append([]string{"ssh"}, sshArgs...), " "),
330		)
331		// Disable git prompting for credentials.
332		ts.Setenv("GIT_TERMINAL_PROMPT", "0")
333		args = append([]string{
334			"-c", "user.email=john@example.com",
335			"-c", "user.name=John Doe",
336		}, args...)
337		check(ts, ts.Exec("git", args...), neg)
338	}
339}
340
341func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
342	if len(args) < 2 {
343		ts.Fatalf("usage: mkfile path content")
344	}
345	check(ts, os.WriteFile(
346		ts.MkAbs(args[0]),
347		[]byte(strings.Join(args[1:], " ")),
348		0o644,
349	), neg)
350}
351
352func check(ts *testscript.TestScript, err error, neg bool) {
353	if neg && err == nil {
354		ts.Fatalf("expected error, got nil")
355	}
356	if !neg {
357		ts.Check(err)
358	}
359}
360
361func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) {
362	ts.Stdout().Write([]byte(ts.ReadFile(args[0])))
363}
364
365func cmdEnvfile(ts *testscript.TestScript, neg bool, args []string) {
366	if len(args) < 1 {
367		ts.Fatalf("usage: envfile key=file...")
368	}
369
370	for _, arg := range args {
371		parts := strings.SplitN(arg, "=", 2)
372		if len(parts) != 2 {
373			ts.Fatalf("usage: envfile key=file...")
374		}
375		key := parts[0]
376		file := parts[1]
377		ts.Setenv(key, strings.TrimSpace(ts.ReadFile(file)))
378	}
379}
380
381func cmdNewWebhook(ts *testscript.TestScript, neg bool, args []string) {
382	type webhookSite struct {
383		UUID string `json:"uuid"`
384	}
385
386	if len(args) != 1 {
387		ts.Fatalf("usage: new-webhook <env-name>")
388	}
389
390	const whSite = "https://webhook.site"
391	req, err := http.NewRequest(http.MethodPost, whSite+"/token", nil) //nolint:noctx
392	check(ts, err, neg)
393
394	resp, err := http.DefaultClient.Do(req)
395	check(ts, err, neg)
396
397	defer resp.Body.Close()
398	var site webhookSite
399	check(ts, json.NewDecoder(resp.Body).Decode(&site), neg)
400
401	ts.Setenv(args[0], whSite+"/"+site.UUID)
402}
403
404func cmdCurl(ts *testscript.TestScript, neg bool, args []string) {
405	var verbose bool
406	var headers []string
407	var data string
408	method := http.MethodGet
409
410	cmd := &cobra.Command{
411		Use:  "curl",
412		Args: cobra.MinimumNArgs(1),
413		RunE: func(cmd *cobra.Command, args []string) error {
414			url, err := url.Parse(args[0])
415			if err != nil {
416				return err
417			}
418
419			req, err := http.NewRequest(method, url.String(), nil) //nolint:noctx
420			if err != nil {
421				return err
422			}
423
424			if data != "" {
425				req.Body = io.NopCloser(strings.NewReader(data))
426			}
427
428			if verbose {
429				fmt.Fprintf(cmd.ErrOrStderr(), "< %s %s\n", req.Method, url.String())
430			}
431
432			for _, header := range headers {
433				parts := strings.SplitN(header, ":", 2)
434				if len(parts) != 2 {
435					return fmt.Errorf("invalid header: %s", header)
436				}
437				req.Header.Add(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
438			}
439
440			if userInfo := url.User; userInfo != nil {
441				password, _ := userInfo.Password()
442				req.SetBasicAuth(userInfo.Username(), password)
443			}
444
445			if verbose {
446				for key, values := range req.Header {
447					for _, value := range values {
448						fmt.Fprintf(cmd.ErrOrStderr(), "< %s: %s\n", key, value)
449					}
450				}
451			}
452
453			resp, err := http.DefaultClient.Do(req)
454			if err != nil {
455				return err
456			}
457
458			if verbose {
459				fmt.Fprintf(ts.Stderr(), "> %s\n", resp.Status)
460				for key, values := range resp.Header {
461					for _, value := range values {
462						fmt.Fprintf(cmd.ErrOrStderr(), "> %s: %s\n", key, value)
463					}
464				}
465			}
466
467			defer resp.Body.Close()
468			buf, err := io.ReadAll(resp.Body)
469			if err != nil {
470				return err
471			}
472
473			cmd.Print(string(buf))
474
475			return nil
476		},
477	}
478
479	cmd.SetArgs(args)
480	cmd.SetOut(ts.Stdout())
481	cmd.SetErr(ts.Stderr())
482
483	cmd.Flags().BoolVarP(&verbose, "verbose", "v", verbose, "verbose")
484	cmd.Flags().StringArrayVarP(&headers, "header", "H", nil, "HTTP header")
485	cmd.Flags().StringVarP(&method, "request", "X", method, "HTTP method")
486	cmd.Flags().StringVarP(&data, "data", "d", data, "HTTP data")
487
488	check(ts, cmd.Execute(), neg)
489}
490
491func cmdEnsureServerRunning(ts *testscript.TestScript, neg bool, args []string) {
492	if len(args) < 1 {
493		ts.Fatalf("Must supply a TCP port of one of the services to connect to. " +
494			"These are set as env vars as they are randomized. " +
495			"Example usage: \"cmdensureserverrunning SSH_PORT\"\n" +
496			"Valid values for the env var: SSH_PORT|HTTP_PORT|GIT_PORT|STATS_PORT")
497	}
498
499	port := ts.Getenv(args[0])
500
501	// verify that the server is up
502	addr := net.JoinHostPort("localhost", port)
503	for {
504		conn, _ := net.DialTimeout( //nolint:noctx
505			"tcp",
506			addr,
507			time.Second,
508		)
509		if conn != nil {
510			ts.Logf("Server is running on port: %s", port)
511			conn.Close()
512			break
513		}
514	}
515}
516
517func cmdEnsureServerNotRunning(ts *testscript.TestScript, neg bool, args []string) {
518	if len(args) < 1 {
519		ts.Fatalf("Must supply a TCP port of one of the services to connect to. " +
520			"These are set as env vars as they are randomized. " +
521			"Example usage: \"cmdensureservernotrunning SSH_PORT\"\n" +
522			"Valid values for the env var: SSH_PORT|HTTP_PORT|GIT_PORT|STATS_PORT")
523	}
524
525	port := ts.Getenv(args[0])
526
527	// verify that the server is not up
528	addr := net.JoinHostPort("localhost", port)
529	conn, _ := net.DialTimeout( //nolint:noctx
530		"tcp",
531		addr,
532		time.Second,
533	)
534	if conn != nil {
535		ts.Fatalf("server is running on port %s while it should not be running", port)
536		conn.Close()
537	}
538}
539
540func cmdStopserver(ts *testscript.TestScript, neg bool, args []string) {
541	// stop the server
542	resp, err := http.DefaultClient.Head(fmt.Sprintf("%s/__stop", ts.Getenv("SOFT_SERVE_HTTP_PUBLIC_URL"))) //nolint:noctx
543	check(ts, err, neg)
544	resp.Body.Close()
545	time.Sleep(time.Second * 2) // Allow some time for the server to stop
546}
547
548func setupPostgres(t testscript.T, cfg *config.Config) (func(), error) {
549	// Indicates postgres
550	// Create a disposable database
551	rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
552	dbName := fmt.Sprintf("softserve_test_%d", rnd.Int63())
553	dbDsn := cfg.DB.DataSource
554	if dbDsn == "" {
555		cfg.DB.DataSource = "postgres://postgres@localhost:5432/postgres?sslmode=disable"
556	}
557
558	dbUrl, err := url.Parse(cfg.DB.DataSource)
559	if err != nil {
560		return nil, err
561	}
562
563	scheme := dbUrl.Scheme
564	if scheme == "" {
565		scheme = "postgres"
566	}
567
568	host := dbUrl.Hostname()
569	if host == "" {
570		host = "localhost"
571	}
572
573	connInfo := fmt.Sprintf("host=%s sslmode=disable", host)
574	username := dbUrl.User.Username()
575	if username != "" {
576		connInfo += fmt.Sprintf(" user=%s", username)
577		password, ok := dbUrl.User.Password()
578		if ok {
579			username = fmt.Sprintf("%s:%s", username, password)
580			connInfo += fmt.Sprintf(" password=%s", password)
581		}
582		username = fmt.Sprintf("%s@", username)
583	} else {
584		connInfo += " user=postgres"
585		username = "postgres@"
586	}
587
588	port := dbUrl.Port()
589	if port != "" {
590		connInfo += fmt.Sprintf(" port=%s", port)
591		port = fmt.Sprintf(":%s", port)
592	}
593
594	cfg.DB.DataSource = fmt.Sprintf("%s://%s%s%s/%s?sslmode=disable",
595		scheme,
596		username,
597		host,
598		port,
599		dbName,
600	)
601
602	// Create the database
603	dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
604	if err != nil {
605		return nil, err
606	}
607
608	if _, err := dbx.ExecContext(context.TODO(), "CREATE DATABASE "+dbName); err != nil {
609		return nil, err
610	}
611
612	return func() {
613		dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
614		if err != nil {
615			t.Fatal("failed to open database", dbName, err)
616		}
617
618		if _, err := dbx.ExecContext(context.TODO(), "DROP DATABASE "+dbName); err != nil {
619			t.Fatal("failed to drop database", dbName, err)
620		}
621	}, nil
622}
623
624type maliciousSigner struct {
625	publicKey ssh.PublicKey
626}
627
628var _ ssh.Signer = (*maliciousSigner)(nil)
629
630// PublicKey implements ssh.Signer.
631func (m *maliciousSigner) PublicKey() ssh.PublicKey {
632	return m.publicKey
633}
634
635// Sign implements ssh.Signer.
636func (m *maliciousSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
637	// The attacker doesn't know how to sign the data without a private key.
638	return &ssh.Signature{}, nil
639}