script_test.go

  1package testscript
  2
  3import (
  4	"bytes"
  5	"context"
  6	"encoding/json"
  7	"flag"
  8	"fmt"
  9	"io"
 10	"math/rand"
 11	"net"
 12	"net/http"
 13	"net/url"
 14	"os"
 15	"os/exec"
 16	"path/filepath"
 17	"runtime"
 18	"strconv"
 19	"strings"
 20	"testing"
 21	"time"
 22
 23	"github.com/charmbracelet/keygen"
 24	"github.com/charmbracelet/soft-serve/pkg/config"
 25	"github.com/charmbracelet/soft-serve/pkg/db"
 26	"github.com/charmbracelet/soft-serve/pkg/test"
 27	"github.com/rogpeppe/go-internal/testscript"
 28	"github.com/spf13/cobra"
 29	"golang.org/x/crypto/ssh"
 30)
 31
 32var (
 33	update  = flag.Bool("update", false, "update script files")
 34	binPath string
 35)
 36
 37func TestMain(m *testing.M) {
 38	tmp, err := os.MkdirTemp("", "soft-serve*")
 39	if err != nil {
 40		fmt.Fprintf(os.Stderr, "failed to create temporary directory: %s", err)
 41		os.Exit(1)
 42	}
 43	defer os.RemoveAll(tmp)
 44
 45	binPath = filepath.Join(tmp, "soft")
 46	if runtime.GOOS == "windows" {
 47		binPath += ".exe"
 48	}
 49
 50	// Build the soft binary with -cover flag.
 51	cmd := exec.Command("go", "build", "-race", "-cover", "-o", binPath, filepath.Join("..", "cmd", "soft"))
 52	if err := cmd.Run(); err != nil {
 53		fmt.Fprintf(os.Stderr, "failed to build soft-serve binary: %s", err)
 54		os.Exit(1)
 55	}
 56
 57	// Run tests
 58	os.Exit(m.Run())
 59}
 60
 61func TestScript(t *testing.T) {
 62	flag.Parse()
 63
 64	mkkey := func(name string) (string, *keygen.SSHKeyPair) {
 65		path := filepath.Join(t.TempDir(), name)
 66		pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
 67		if err != nil {
 68			t.Fatal(err)
 69		}
 70		return path, pair
 71	}
 72
 73	admin1Key, admin1 := mkkey("admin1")
 74	_, admin2 := mkkey("admin2")
 75	user1Key, user1 := mkkey("user1")
 76
 77	testscript.Run(t, testscript.Params{
 78		Dir:                 "./testdata/",
 79		UpdateScripts:       *update,
 80		RequireExplicitExec: true,
 81		Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
 82			"soft":                   cmdSoft("admin", admin1.Signer()),
 83			"usoft":                  cmdSoft("user1", user1.Signer()),
 84			"git":                    cmdGit(admin1Key),
 85			"ugit":                   cmdGit(user1Key),
 86			"curl":                   cmdCurl,
 87			"mkfile":                 cmdMkfile,
 88			"envfile":                cmdEnvfile,
 89			"readfile":               cmdReadfile,
 90			"dos2unix":               cmdDos2Unix,
 91			"new-webhook":            cmdNewWebhook,
 92			"ensureserverrunning":    cmdEnsureServerRunning,
 93			"ensureservernotrunning": cmdEnsureServerNotRunning,
 94			"stopserver":             cmdStopserver,
 95			"ui":                     cmdUI(admin1.Signer()),
 96			"uui":                    cmdUI(user1.Signer()),
 97		},
 98		Setup: func(e *testscript.Env) error {
 99			// Add binPath to PATH
100			e.Setenv("PATH", fmt.Sprintf("%s%c%s", filepath.Dir(binPath), os.PathListSeparator, e.Getenv("PATH")))
101
102			data := t.TempDir()
103			sshPort := test.RandomPort()
104			sshListen := fmt.Sprintf("localhost:%d", sshPort)
105			gitPort := test.RandomPort()
106			gitListen := fmt.Sprintf("localhost:%d", gitPort)
107			httpPort := test.RandomPort()
108			httpListen := fmt.Sprintf("localhost:%d", httpPort)
109			statsPort := test.RandomPort()
110			statsListen := fmt.Sprintf("localhost:%d", statsPort)
111			serverName := "Test Soft Serve"
112
113			e.Setenv("DATA_PATH", data)
114			e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
115			e.Setenv("HTTP_PORT", fmt.Sprintf("%d", httpPort))
116			e.Setenv("STATS_PORT", fmt.Sprintf("%d", statsPort))
117			e.Setenv("GIT_PORT", fmt.Sprintf("%d", gitPort))
118			e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
119			e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
120			e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
121			e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
122			e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
123
124			// This is used to set up test specific configuration and http endpoints
125			e.Setenv("SOFT_SERVE_TESTRUN", "1")
126
127			// This will disable the default lipgloss renderer colors
128			e.Setenv("SOFT_SERVE_NO_COLOR", "1")
129
130			// Soft Serve debug environment variables
131			for _, env := range []string{
132				"SOFT_SERVE_DEBUG",
133				"SOFT_SERVE_VERBOSE",
134			} {
135				if v, ok := os.LookupEnv(env); ok {
136					e.Setenv(env, v)
137				}
138			}
139
140			// TODO: test different configs
141			cfg := config.DefaultConfig()
142			cfg.DataPath = data
143			cfg.Name = serverName
144			cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
145			cfg.SSH.ListenAddr = sshListen
146			cfg.SSH.PublicURL = "ssh://" + sshListen
147			cfg.Git.ListenAddr = gitListen
148			cfg.HTTP.ListenAddr = httpListen
149			cfg.HTTP.PublicURL = "http://" + httpListen
150			cfg.Stats.ListenAddr = statsListen
151			cfg.LFS.Enabled = true
152
153			// Parse os SOFT_SERVE environment variables
154			if err := cfg.ParseEnv(); err != nil {
155				return err
156			}
157
158			// Override the database data source if we're using postgres
159			// so we can create a temporary database for the tests.
160			if cfg.DB.Driver == "postgres" {
161				err, cleanup := setupPostgres(e.T(), cfg)
162				if err != nil {
163					return err
164				}
165				if cleanup != nil {
166					e.Defer(cleanup)
167				}
168			}
169
170			for _, env := range cfg.Environ() {
171				parts := strings.SplitN(env, "=", 2)
172				if len(parts) != 2 {
173					e.T().Fatal("invalid environment variable", env)
174				}
175				e.Setenv(parts[0], parts[1])
176			}
177
178			return nil
179		},
180	})
181}
182
183func cmdSoft(user string, key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
184	return func(ts *testscript.TestScript, neg bool, args []string) {
185		cli, err := ssh.Dial(
186			"tcp",
187			net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
188			&ssh.ClientConfig{
189				User:            user,
190				Auth:            []ssh.AuthMethod{ssh.PublicKeys(key)},
191				HostKeyCallback: ssh.InsecureIgnoreHostKey(),
192			},
193		)
194		ts.Check(err)
195		defer cli.Close()
196
197		sess, err := cli.NewSession()
198		ts.Check(err)
199		defer sess.Close()
200
201		sess.Stdout = ts.Stdout()
202		sess.Stderr = ts.Stderr()
203
204		check(ts, sess.Run(strings.Join(args, " ")), neg)
205	}
206}
207
208func cmdUI(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
209	return func(ts *testscript.TestScript, neg bool, args []string) {
210		if len(args) < 1 {
211			ts.Fatalf("usage: ui <quoted string input>")
212			return
213		}
214
215		cli, err := ssh.Dial(
216			"tcp",
217			net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
218			&ssh.ClientConfig{
219				User:            "git",
220				Auth:            []ssh.AuthMethod{ssh.PublicKeys(key)},
221				HostKeyCallback: ssh.InsecureIgnoreHostKey(),
222			},
223		)
224		check(ts, err, neg)
225		defer cli.Close()
226
227		sess, err := cli.NewSession()
228		check(ts, err, neg)
229		defer sess.Close()
230
231		// XXX: this is a hack to make the UI tests work
232		// cmp command always complains about an extra newline
233		// in the output
234		defer ts.Stdout().Write([]byte("\n"))
235
236		sess.Stdout = ts.Stdout()
237		sess.Stderr = ts.Stderr()
238
239		stdin, err := sess.StdinPipe()
240		check(ts, err, neg)
241
242		err = sess.RequestPty("dumb", 40, 80, ssh.TerminalModes{})
243		check(ts, err, neg)
244		check(ts, sess.Start(""), neg)
245
246		in, err := strconv.Unquote(args[0])
247		check(ts, err, neg)
248		reader := strings.NewReader(in)
249		go func() {
250			defer stdin.Close()
251			for {
252				r, _, err := reader.ReadRune()
253				if err == io.EOF {
254					break
255				}
256				check(ts, err, neg)
257				stdin.Write([]byte(string(r))) // nolint: errcheck
258
259				// Wait for the UI to process the input
260				time.Sleep(100 * time.Millisecond)
261			}
262		}()
263
264		check(ts, sess.Wait(), neg)
265	}
266}
267
268// P.S. Windows sucks!
269func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
270	if neg {
271		ts.Fatalf("unsupported: ! dos2unix")
272	}
273	if len(args) < 1 {
274		ts.Fatalf("usage: dos2unix paths...")
275	}
276	for _, arg := range args {
277		filename := ts.MkAbs(arg)
278		data, err := os.ReadFile(filename)
279		if err != nil {
280			ts.Fatalf("%s: %v", filename, err)
281		}
282
283		// Replace all '\r\n' with '\n'.
284		data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
285
286		if err := os.WriteFile(filename, data, 0o644); err != nil {
287			ts.Fatalf("%s: %v", filename, err)
288		}
289	}
290}
291
// sshConfig is an ssh_config template written out by cmdGit before every git
// invocation. Its single %q verb receives the known_hosts file path. Host
// key checking is relaxed, and only the identity passed with -i is offered
// (no ssh-agent).
var sshConfig = "\n" +
	"Host *\n" +
	"  UserKnownHostsFile %q\n" +
	"  StrictHostKeyChecking no\n" +
	"  IdentityAgent none\n" +
	"  IdentitiesOnly yes\n" +
	"  ServerAliveInterval 60\n"
300
301func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
302	return func(ts *testscript.TestScript, neg bool, args []string) {
303		ts.Check(os.WriteFile(
304			ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
305			[]byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
306			0o600,
307		))
308		sshArgs := []string{
309			"-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
310			"-i", filepath.ToSlash(key),
311		}
312		ts.Setenv(
313			"GIT_SSH_COMMAND",
314			strings.Join(append([]string{"ssh"}, sshArgs...), " "),
315		)
316		// Disable git prompting for credentials.
317		ts.Setenv("GIT_TERMINAL_PROMPT", "0")
318		args = append([]string{
319			"-c", "user.email=john@example.com",
320			"-c", "user.name=John Doe",
321		}, args...)
322		check(ts, ts.Exec("git", args...), neg)
323	}
324}
325
326func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
327	if len(args) < 2 {
328		ts.Fatalf("usage: mkfile path content")
329	}
330	check(ts, os.WriteFile(
331		ts.MkAbs(args[0]),
332		[]byte(strings.Join(args[1:], " ")),
333		0o644,
334	), neg)
335}
336
337func check(ts *testscript.TestScript, err error, neg bool) {
338	if neg && err == nil {
339		ts.Fatalf("expected error, got nil")
340	}
341	if !neg {
342		ts.Check(err)
343	}
344}
345
346func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) {
347	ts.Stdout().Write([]byte(ts.ReadFile(args[0])))
348}
349
350func cmdEnvfile(ts *testscript.TestScript, neg bool, args []string) {
351	if len(args) < 1 {
352		ts.Fatalf("usage: envfile key=file...")
353	}
354
355	for _, arg := range args {
356		parts := strings.SplitN(arg, "=", 2)
357		if len(parts) != 2 {
358			ts.Fatalf("usage: envfile key=file...")
359		}
360		key := parts[0]
361		file := parts[1]
362		ts.Setenv(key, strings.TrimSpace(ts.ReadFile(file)))
363	}
364}
365
366func cmdNewWebhook(ts *testscript.TestScript, neg bool, args []string) {
367	type webhookSite struct {
368		UUID string `json:"uuid"`
369	}
370
371	if len(args) != 1 {
372		ts.Fatalf("usage: new-webhook <env-name>")
373	}
374
375	const whSite = "https://webhook.site"
376	req, err := http.NewRequest(http.MethodPost, whSite+"/token", nil)
377	check(ts, err, neg)
378
379	resp, err := http.DefaultClient.Do(req)
380	check(ts, err, neg)
381
382	defer resp.Body.Close()
383	var site webhookSite
384	check(ts, json.NewDecoder(resp.Body).Decode(&site), neg)
385
386	ts.Setenv(args[0], whSite+"/"+site.UUID)
387}
388
389func cmdCurl(ts *testscript.TestScript, neg bool, args []string) {
390	var verbose bool
391	var headers []string
392	var data string
393	method := http.MethodGet
394
395	cmd := &cobra.Command{
396		Use:  "curl",
397		Args: cobra.MinimumNArgs(1),
398		RunE: func(cmd *cobra.Command, args []string) error {
399			url, err := url.Parse(args[0])
400			if err != nil {
401				return err
402			}
403
404			req, err := http.NewRequest(method, url.String(), nil)
405			if err != nil {
406				return err
407			}
408
409			if data != "" {
410				req.Body = io.NopCloser(strings.NewReader(data))
411			}
412
413			if verbose {
414				fmt.Fprintf(cmd.ErrOrStderr(), "< %s %s\n", req.Method, url.String())
415			}
416
417			for _, header := range headers {
418				parts := strings.SplitN(header, ":", 2)
419				if len(parts) != 2 {
420					return fmt.Errorf("invalid header: %s", header)
421				}
422				req.Header.Add(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
423			}
424
425			if userInfo := url.User; userInfo != nil {
426				password, _ := userInfo.Password()
427				req.SetBasicAuth(userInfo.Username(), password)
428			}
429
430			if verbose {
431				for key, values := range req.Header {
432					for _, value := range values {
433						fmt.Fprintf(cmd.ErrOrStderr(), "< %s: %s\n", key, value)
434					}
435				}
436			}
437
438			resp, err := http.DefaultClient.Do(req)
439			if err != nil {
440				return err
441			}
442
443			if verbose {
444				fmt.Fprintf(ts.Stderr(), "> %s\n", resp.Status)
445				for key, values := range resp.Header {
446					for _, value := range values {
447						fmt.Fprintf(cmd.ErrOrStderr(), "> %s: %s\n", key, value)
448					}
449				}
450			}
451
452			defer resp.Body.Close()
453			buf, err := io.ReadAll(resp.Body)
454			if err != nil {
455				return err
456			}
457
458			cmd.Print(string(buf))
459
460			return nil
461		},
462	}
463
464	cmd.SetArgs(args)
465	cmd.SetOut(ts.Stdout())
466	cmd.SetErr(ts.Stderr())
467
468	cmd.Flags().BoolVarP(&verbose, "verbose", "v", verbose, "verbose")
469	cmd.Flags().StringArrayVarP(&headers, "header", "H", nil, "HTTP header")
470	cmd.Flags().StringVarP(&method, "request", "X", method, "HTTP method")
471	cmd.Flags().StringVarP(&data, "data", "d", data, "HTTP data")
472
473	check(ts, cmd.Execute(), neg)
474}
475
476func cmdEnsureServerRunning(ts *testscript.TestScript, neg bool, args []string) {
477	if len(args) < 1 {
478		ts.Fatalf("Must supply a TCP port of one of the services to connect to. " +
479			"These are set as env vars as they are randomized. " +
480			"Example usage: \"cmdensureserverrunning SSH_PORT\"\n" +
481			"Valid values for the env var: SSH_PORT|HTTP_PORT|GIT_PORT|STATS_PORT")
482	}
483
484	port := ts.Getenv(args[0])
485
486	// verify that the server is up
487	addr := net.JoinHostPort("localhost", port)
488	for {
489		conn, _ := net.DialTimeout(
490			"tcp",
491			addr,
492			time.Second,
493		)
494		if conn != nil {
495			ts.Logf("Server is running on port: %s", port)
496			conn.Close()
497			break
498		}
499	}
500}
501
502func cmdEnsureServerNotRunning(ts *testscript.TestScript, neg bool, args []string) {
503	if len(args) < 1 {
504		ts.Fatalf("Must supply a TCP port of one of the services to connect to. " +
505			"These are set as env vars as they are randomized. " +
506			"Example usage: \"cmdensureservernotrunning SSH_PORT\"\n" +
507			"Valid values for the env var: SSH_PORT|HTTP_PORT|GIT_PORT|STATS_PORT")
508	}
509
510	port := ts.Getenv(args[0])
511
512	// verify that the server is not up
513	addr := net.JoinHostPort("localhost", port)
514	for {
515		conn, _ := net.DialTimeout(
516			"tcp",
517			addr,
518			time.Second,
519		)
520		if conn != nil {
521			ts.Fatalf("server is running on port %s while it should not be running", port)
522			conn.Close()
523		}
524		break
525	}
526}
527
528func cmdStopserver(ts *testscript.TestScript, neg bool, args []string) {
529	// stop the server
530	resp, err := http.DefaultClient.Head(fmt.Sprintf("%s/__stop", ts.Getenv("SOFT_SERVE_HTTP_PUBLIC_URL")))
531	check(ts, err, neg)
532	resp.Body.Close()
533	time.Sleep(time.Second * 2) // Allow some time for the server to stop
534}
535
536func setupPostgres(t testscript.T, cfg *config.Config) (error, func()) {
537	// Indicates postgres
538	// Create a disposable database
539	rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
540	dbName := fmt.Sprintf("softserve_test_%d", rnd.Int63())
541	dbDsn := cfg.DB.DataSource
542	if dbDsn == "" {
543		cfg.DB.DataSource = "postgres://postgres@localhost:5432/postgres?sslmode=disable"
544	}
545
546	dbUrl, err := url.Parse(cfg.DB.DataSource)
547	if err != nil {
548		return err, nil
549	}
550
551	scheme := dbUrl.Scheme
552	if scheme == "" {
553		scheme = "postgres"
554	}
555
556	host := dbUrl.Hostname()
557	if host == "" {
558		host = "localhost"
559	}
560
561	connInfo := fmt.Sprintf("host=%s sslmode=disable", host)
562	username := dbUrl.User.Username()
563	if username != "" {
564		connInfo += fmt.Sprintf(" user=%s", username)
565		password, ok := dbUrl.User.Password()
566		if ok {
567			username = fmt.Sprintf("%s:%s", username, password)
568			connInfo += fmt.Sprintf(" password=%s", password)
569		}
570		username = fmt.Sprintf("%s@", username)
571	} else {
572		connInfo += " user=postgres"
573		username = "postgres@"
574	}
575
576	port := dbUrl.Port()
577	if port != "" {
578		connInfo += fmt.Sprintf(" port=%s", port)
579		port = fmt.Sprintf(":%s", port)
580	}
581
582	cfg.DB.DataSource = fmt.Sprintf("%s://%s%s%s/%s?sslmode=disable",
583		scheme,
584		username,
585		host,
586		port,
587		dbName,
588	)
589
590	// Create the database
591	dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
592	if err != nil {
593		return err, nil
594	}
595
596	if _, err := dbx.Exec("CREATE DATABASE " + dbName); err != nil {
597		return err, nil
598	}
599
600	return nil, func() {
601		dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
602		if err != nil {
603			t.Fatal("failed to open database", dbName, err)
604		}
605
606		if _, err := dbx.Exec("DROP DATABASE " + dbName); err != nil {
607			t.Fatal("failed to drop database", dbName, err)
608		}
609	}
610}