script_test.go

  1package testscript
  2
  3import (
  4	"bytes"
  5	"context"
  6	"encoding/json"
  7	"flag"
  8	"fmt"
  9	"io"
 10	"math/rand"
 11	"net"
 12	"net/http"
 13	"net/url"
 14	"os"
 15	"os/exec"
 16	"path/filepath"
 17	"runtime"
 18	"strconv"
 19	"strings"
 20	"testing"
 21	"time"
 22
 23	"github.com/charmbracelet/keygen"
 24	"github.com/charmbracelet/soft-serve/pkg/config"
 25	"github.com/charmbracelet/soft-serve/pkg/db"
 26	"github.com/charmbracelet/soft-serve/pkg/test"
 27	"github.com/rogpeppe/go-internal/testscript"
 28	"github.com/spf13/cobra"
 29	"golang.org/x/crypto/ssh"
 30)
 31
 32var (
 33	update  = flag.Bool("update", false, "update script files")
 34	binPath string
 35)
 36
 37func TestMain(m *testing.M) {
 38	tmp, err := os.MkdirTemp("", "soft-serve*")
 39	if err != nil {
 40		fmt.Fprintf(os.Stderr, "failed to create temporary directory: %s", err)
 41		os.Exit(1)
 42	}
 43	defer os.RemoveAll(tmp)
 44
 45	binPath = filepath.Join(tmp, "soft")
 46	if runtime.GOOS == "windows" {
 47		binPath += ".exe"
 48	}
 49
 50	// Build the soft binary with -cover flag.
 51	cmd := exec.Command("go", "build", "-race", "-cover", "-o", binPath, filepath.Join("..", "cmd", "soft"))
 52	if err := cmd.Run(); err != nil {
 53		fmt.Fprintf(os.Stderr, "failed to build soft-serve binary: %s", err)
 54		os.Exit(1)
 55	}
 56
 57	// Run tests
 58	os.Exit(m.Run())
 59
 60	// Add binPath to PATH
 61	os.Setenv("PATH", fmt.Sprintf("%s%c%s", os.Getenv("PATH"), os.PathListSeparator, filepath.Dir(binPath)))
 62}
 63
 64func TestScript(t *testing.T) {
 65	flag.Parse()
 66
 67	mkkey := func(name string) (string, *keygen.SSHKeyPair) {
 68		path := filepath.Join(t.TempDir(), name)
 69		pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
 70		if err != nil {
 71			t.Fatal(err)
 72		}
 73		return path, pair
 74	}
 75
 76	key, admin1 := mkkey("admin1")
 77	_, admin2 := mkkey("admin2")
 78	_, user1 := mkkey("user1")
 79	_, user2 := mkkey("user2")
 80
 81	testscript.Run(t, testscript.Params{
 82		Dir:                 "./testdata/",
 83		UpdateScripts:       *update,
 84		RequireExplicitExec: true,
 85		Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
 86			"soft":          cmdSoft(admin1.Signer()),
 87			"usoft":         cmdSoft(user1.Signer()),
 88			"git":           cmdGit(key),
 89			"curl":          cmdCurl,
 90			"mkfile":        cmdMkfile,
 91			"envfile":       cmdEnvfile,
 92			"readfile":      cmdReadfile,
 93			"dos2unix":      cmdDos2Unix,
 94			"new-webhook":   cmdNewWebhook,
 95			"waitforserver": cmdWaitforserver,
 96			"stopserver":    cmdStopserver,
 97			"ui":            cmdUI(admin1.Signer()),
 98			"uui":           cmdUI(user1.Signer()),
 99		},
100		Setup: func(e *testscript.Env) error {
101			// Add binPath to PATH
102			e.Setenv("PATH", fmt.Sprintf("%s%c%s", filepath.Dir(binPath), os.PathListSeparator, e.Getenv("PATH")))
103
104			data := t.TempDir()
105			sshPort := test.RandomPort()
106			sshListen := fmt.Sprintf("localhost:%d", sshPort)
107			gitPort := test.RandomPort()
108			gitListen := fmt.Sprintf("localhost:%d", gitPort)
109			httpPort := test.RandomPort()
110			httpListen := fmt.Sprintf("localhost:%d", httpPort)
111			statsPort := test.RandomPort()
112			statsListen := fmt.Sprintf("localhost:%d", statsPort)
113			serverName := "Test Soft Serve"
114
115			e.Setenv("DATA_PATH", data)
116			e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
117			e.Setenv("HTTP_PORT", fmt.Sprintf("%d", httpPort))
118			e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
119			e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
120			e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
121			e.Setenv("USER2_AUTHORIZED_KEY", user2.AuthorizedKey())
122			e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
123			e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
124
125			// This is used to set up test specific configuration and http endpoints
126			e.Setenv("SOFT_SERVE_TESTRUN", "1")
127
128			// This will disable the default lipgloss renderer colors
129			e.Setenv("SOFT_SERVE_NO_COLOR", "1")
130
131			// Soft Serve debug environment variables
132			for _, env := range []string{
133				"SOFT_SERVE_DEBUG",
134				"SOFT_SERVE_VERBOSE",
135			} {
136				if v, ok := os.LookupEnv(env); ok {
137					e.Setenv(env, v)
138				}
139			}
140
141			// TODO: test different configs
142			cfg := config.DefaultConfig()
143			cfg.DataPath = data
144			cfg.Name = serverName
145			cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
146			cfg.SSH.ListenAddr = sshListen
147			cfg.SSH.PublicURL = "ssh://" + sshListen
148			cfg.Git.ListenAddr = gitListen
149			cfg.HTTP.ListenAddr = httpListen
150			cfg.HTTP.PublicURL = "http://" + httpListen
151			cfg.Stats.ListenAddr = statsListen
152			cfg.LFS.Enabled = true
153
154			// Parse os SOFT_SERVE environment variables
155			if err := cfg.ParseEnv(); err != nil {
156				return err
157			}
158
159			// Override the database data source if we're using postgres
160			// so we can create a temporary database for the tests.
161			if cfg.DB.Driver == "postgres" {
162				err, cleanup := setupPostgres(e.T(), cfg)
163				if err != nil {
164					return err
165				}
166				if cleanup != nil {
167					e.Defer(cleanup)
168				}
169			}
170
171			for _, env := range cfg.Environ() {
172				parts := strings.SplitN(env, "=", 2)
173				if len(parts) != 2 {
174					e.T().Fatal("invalid environment variable", env)
175				}
176				e.Setenv(parts[0], parts[1])
177			}
178
179			return nil
180		},
181	})
182}
183
184func cmdSoft(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
185	return func(ts *testscript.TestScript, neg bool, args []string) {
186		cli, err := ssh.Dial(
187			"tcp",
188			net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
189			&ssh.ClientConfig{
190				User:            "admin",
191				Auth:            []ssh.AuthMethod{ssh.PublicKeys(key)},
192				HostKeyCallback: ssh.InsecureIgnoreHostKey(),
193			},
194		)
195		ts.Check(err)
196		defer cli.Close()
197
198		sess, err := cli.NewSession()
199		ts.Check(err)
200		defer sess.Close()
201
202		sess.Stdout = ts.Stdout()
203		sess.Stderr = ts.Stderr()
204
205		check(ts, sess.Run(strings.Join(args, " ")), neg)
206	}
207}
208
209func cmdUI(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
210	return func(ts *testscript.TestScript, neg bool, args []string) {
211		if len(args) < 1 {
212			ts.Fatalf("usage: ui <quoted string input>")
213			return
214		}
215
216		cli, err := ssh.Dial(
217			"tcp",
218			net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
219			&ssh.ClientConfig{
220				User:            "git",
221				Auth:            []ssh.AuthMethod{ssh.PublicKeys(key)},
222				HostKeyCallback: ssh.InsecureIgnoreHostKey(),
223			},
224		)
225		check(ts, err, neg)
226		defer cli.Close()
227
228		sess, err := cli.NewSession()
229		check(ts, err, neg)
230		defer sess.Close()
231
232		// XXX: this is a hack to make the UI tests work
233		// cmp command always complains about an extra newline
234		// in the output
235		defer ts.Stdout().Write([]byte("\n"))
236
237		sess.Stdout = ts.Stdout()
238		sess.Stderr = ts.Stderr()
239
240		stdin, err := sess.StdinPipe()
241		check(ts, err, neg)
242
243		in, err := strconv.Unquote(args[0])
244		check(ts, err, neg)
245		reader := strings.NewReader(in)
246		go func() {
247			defer stdin.Close()
248			for {
249				r, _, err := reader.ReadRune()
250				if err == io.EOF {
251					break
252				}
253				check(ts, err, neg)
254				stdin.Write([]byte(string(r))) // nolint: errcheck
255
256				// Wait for the UI to process the input
257				time.Sleep(100 * time.Millisecond)
258			}
259		}()
260
261		err = sess.RequestPty("dumb", 40, 80, ssh.TerminalModes{})
262		check(ts, err, neg)
263		check(ts, sess.Run(""), neg)
264	}
265}
266
267// P.S. Windows sucks!
268func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
269	if neg {
270		ts.Fatalf("unsupported: ! dos2unix")
271	}
272	if len(args) < 1 {
273		ts.Fatalf("usage: dos2unix paths...")
274	}
275	for _, arg := range args {
276		filename := ts.MkAbs(arg)
277		data, err := os.ReadFile(filename)
278		if err != nil {
279			ts.Fatalf("%s: %v", filename, err)
280		}
281
282		// Replace all '\r\n' with '\n'.
283		data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
284
285		if err := os.WriteFile(filename, data, 0o644); err != nil {
286			ts.Fatalf("%s: %v", filename, err)
287		}
288	}
289}
290
// sshConfig is the OpenSSH client configuration template that cmdGit writes
// before each git invocation. Its single %q verb receives the known-hosts file
// path. The settings have the following effects:
//   - unknown host keys are accepted without prompting;
//   - no ssh-agent is consulted;
//   - only explicitly configured identities (e.g. via -i) are offered;
//   - a keepalive is sent every 60 seconds.
var sshConfig = "\n" +
	"Host *\n" +
	"  UserKnownHostsFile %q\n" +
	"  StrictHostKeyChecking no\n" +
	"  IdentityAgent none\n" +
	"  IdentitiesOnly yes\n" +
	"  ServerAliveInterval 60\n"
299
300func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
301	return func(ts *testscript.TestScript, neg bool, args []string) {
302		ts.Check(os.WriteFile(
303			ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
304			[]byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
305			0o600,
306		))
307		sshArgs := []string{
308			"-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
309			"-i", filepath.ToSlash(key),
310		}
311		ts.Setenv(
312			"GIT_SSH_COMMAND",
313			strings.Join(append([]string{"ssh"}, sshArgs...), " "),
314		)
315		// Disable git prompting for credentials.
316		ts.Setenv("GIT_TERMINAL_PROMPT", "0")
317		args = append([]string{
318			"-c", "user.email=john@example.com",
319			"-c", "user.name=John Doe",
320		}, args...)
321		check(ts, ts.Exec("git", args...), neg)
322	}
323}
324
325func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
326	if len(args) < 2 {
327		ts.Fatalf("usage: mkfile path content")
328	}
329	check(ts, os.WriteFile(
330		ts.MkAbs(args[0]),
331		[]byte(strings.Join(args[1:], " ")),
332		0o644,
333	), neg)
334}
335
336func check(ts *testscript.TestScript, err error, neg bool) {
337	if neg && err == nil {
338		ts.Fatalf("expected error, got nil")
339	}
340	if !neg {
341		ts.Check(err)
342	}
343}
344
345func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) {
346	ts.Stdout().Write([]byte(ts.ReadFile(args[0])))
347}
348
349func cmdEnvfile(ts *testscript.TestScript, neg bool, args []string) {
350	if len(args) < 1 {
351		ts.Fatalf("usage: envfile key=file...")
352	}
353
354	for _, arg := range args {
355		parts := strings.SplitN(arg, "=", 2)
356		if len(parts) != 2 {
357			ts.Fatalf("usage: envfile key=file...")
358		}
359		key := parts[0]
360		file := parts[1]
361		ts.Setenv(key, strings.TrimSpace(ts.ReadFile(file)))
362	}
363}
364
365func cmdNewWebhook(ts *testscript.TestScript, neg bool, args []string) {
366	type webhookSite struct {
367		UUID string `json:"uuid"`
368	}
369
370	if len(args) != 1 {
371		ts.Fatalf("usage: new-webhook <env-name>")
372	}
373
374	const whSite = "https://webhook.site"
375	req, err := http.NewRequest(http.MethodPost, whSite+"/token", nil)
376	check(ts, err, neg)
377
378	resp, err := http.DefaultClient.Do(req)
379	check(ts, err, neg)
380
381	defer resp.Body.Close()
382	var site webhookSite
383	check(ts, json.NewDecoder(resp.Body).Decode(&site), neg)
384
385	ts.Setenv(args[0], whSite+"/"+site.UUID)
386}
387
388func cmdCurl(ts *testscript.TestScript, neg bool, args []string) {
389	var verbose bool
390	var headers []string
391	var data string
392	method := http.MethodGet
393
394	cmd := &cobra.Command{
395		Use:  "curl",
396		Args: cobra.MinimumNArgs(1),
397		RunE: func(cmd *cobra.Command, args []string) error {
398			url, err := url.Parse(args[0])
399			if err != nil {
400				return err
401			}
402
403			req, err := http.NewRequest(method, url.String(), nil)
404			if err != nil {
405				return err
406			}
407
408			if data != "" {
409				req.Body = io.NopCloser(strings.NewReader(data))
410			}
411
412			if verbose {
413				fmt.Fprintf(cmd.ErrOrStderr(), "< %s %s\n", req.Method, url.String())
414			}
415
416			for _, header := range headers {
417				parts := strings.SplitN(header, ":", 2)
418				if len(parts) != 2 {
419					return fmt.Errorf("invalid header: %s", header)
420				}
421				req.Header.Add(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
422			}
423
424			if userInfo := url.User; userInfo != nil {
425				password, _ := userInfo.Password()
426				req.SetBasicAuth(userInfo.Username(), password)
427			}
428
429			if verbose {
430				for key, values := range req.Header {
431					for _, value := range values {
432						fmt.Fprintf(cmd.ErrOrStderr(), "< %s: %s\n", key, value)
433					}
434				}
435			}
436
437			resp, err := http.DefaultClient.Do(req)
438			if err != nil {
439				return err
440			}
441
442			if verbose {
443				fmt.Fprintf(ts.Stderr(), "> %s\n", resp.Status)
444				for key, values := range resp.Header {
445					for _, value := range values {
446						fmt.Fprintf(cmd.ErrOrStderr(), "> %s: %s\n", key, value)
447					}
448				}
449			}
450
451			defer resp.Body.Close()
452			buf, err := io.ReadAll(resp.Body)
453			if err != nil {
454				return err
455			}
456
457			cmd.Print(string(buf))
458
459			return nil
460		},
461	}
462
463	cmd.SetArgs(args)
464	cmd.SetOut(ts.Stdout())
465	cmd.SetErr(ts.Stderr())
466
467	cmd.Flags().BoolVarP(&verbose, "verbose", "v", verbose, "verbose")
468	cmd.Flags().StringArrayVarP(&headers, "header", "H", nil, "HTTP header")
469	cmd.Flags().StringVarP(&method, "request", "X", method, "HTTP method")
470	cmd.Flags().StringVarP(&data, "data", "d", data, "HTTP data")
471
472	check(ts, cmd.Execute(), neg)
473}
474
475func cmdWaitforserver(ts *testscript.TestScript, neg bool, args []string) {
476	// wait until the server is up
477	for {
478		conn, _ := net.DialTimeout(
479			"tcp",
480			net.JoinHostPort("localhost", fmt.Sprintf("%s", ts.Getenv("SSH_PORT"))),
481			time.Second,
482		)
483		if conn != nil {
484			conn.Close()
485			break
486		}
487	}
488}
489
490func cmdStopserver(ts *testscript.TestScript, neg bool, args []string) {
491	// stop the server
492	resp, err := http.DefaultClient.Head(fmt.Sprintf("%s/__stop", ts.Getenv("SOFT_SERVE_HTTP_PUBLIC_URL")))
493	check(ts, err, neg)
494	defer resp.Body.Close()
495	time.Sleep(time.Second * 2) // Allow some time for the server to stop
496}
497
498func setupPostgres(t testscript.T, cfg *config.Config) (error, func()) {
499	// Indicates postgres
500	// Create a disposable database
501	rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
502	dbName := fmt.Sprintf("softserve_test_%d", rnd.Int63())
503	dbDsn := cfg.DB.DataSource
504	if dbDsn == "" {
505		cfg.DB.DataSource = "postgres://postgres@localhost:5432/postgres?sslmode=disable"
506	}
507
508	dbUrl, err := url.Parse(cfg.DB.DataSource)
509	if err != nil {
510		return err, nil
511	}
512
513	scheme := dbUrl.Scheme
514	if scheme == "" {
515		scheme = "postgres"
516	}
517
518	host := dbUrl.Hostname()
519	if host == "" {
520		host = "localhost"
521	}
522
523	connInfo := fmt.Sprintf("host=%s sslmode=disable", host)
524	username := dbUrl.User.Username()
525	if username != "" {
526		connInfo += fmt.Sprintf(" user=%s", username)
527		password, ok := dbUrl.User.Password()
528		if ok {
529			username = fmt.Sprintf("%s:%s", username, password)
530			connInfo += fmt.Sprintf(" password=%s", password)
531		}
532		username = fmt.Sprintf("%s@", username)
533	} else {
534		connInfo += " user=postgres"
535		username = "postgres@"
536	}
537
538	port := dbUrl.Port()
539	if port != "" {
540		connInfo += fmt.Sprintf(" port=%s", port)
541		port = fmt.Sprintf(":%s", port)
542	}
543
544	cfg.DB.DataSource = fmt.Sprintf("%s://%s%s%s/%s?sslmode=disable",
545		scheme,
546		username,
547		host,
548		port,
549		dbName,
550	)
551
552	// Create the database
553	dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
554	if err != nil {
555		return err, nil
556	}
557
558	if _, err := dbx.Exec("CREATE DATABASE " + dbName); err != nil {
559		return err, nil
560	}
561
562	return nil, func() {
563		dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
564		if err != nil {
565			t.Fatal("failed to open database", dbName, err)
566		}
567
568		if _, err := dbx.Exec("DROP DATABASE " + dbName); err != nil {
569			t.Fatal("failed to drop database", dbName, err)
570		}
571	}
572}