1package testscript
  2
  3import (
  4	"bytes"
  5	"context"
  6	"encoding/json"
  7	"flag"
  8	"fmt"
  9	"io"
 10	"math/rand"
 11	"net"
 12	"net/http"
 13	"net/url"
 14	"os"
 15	"os/exec"
 16	"path/filepath"
 17	"runtime"
 18	"strconv"
 19	"strings"
 20	"testing"
 21	"time"
 22
 23	"github.com/charmbracelet/keygen"
 24	"github.com/charmbracelet/soft-serve/pkg/config"
 25	"github.com/charmbracelet/soft-serve/pkg/db"
 26	"github.com/charmbracelet/soft-serve/pkg/test"
 27	"github.com/rogpeppe/go-internal/testscript"
 28	"github.com/spf13/cobra"
 29	"golang.org/x/crypto/ssh"
 30)
 31
 32var (
 33	update  = flag.Bool("update", false, "update script files")
 34	binPath string
 35)
 36
 37func TestMain(m *testing.M) {
 38	tmp, err := os.MkdirTemp("", "soft-serve*")
 39	if err != nil {
 40		fmt.Fprintf(os.Stderr, "failed to create temporary directory: %s", err)
 41		os.Exit(1)
 42	}
 43	defer os.RemoveAll(tmp)
 44
 45	binPath = filepath.Join(tmp, "soft")
 46	if runtime.GOOS == "windows" {
 47		binPath += ".exe"
 48	}
 49
 50	// Build the soft binary with -cover flag.
 51	cmd := exec.Command("go", "build", "-race", "-cover", "-o", binPath, filepath.Join("..", "cmd", "soft"))
 52	if err := cmd.Run(); err != nil {
 53		fmt.Fprintf(os.Stderr, "failed to build soft-serve binary: %s", err)
 54		os.Exit(1)
 55	}
 56
 57	// Run tests
 58	os.Exit(m.Run())
 59}
 60
 61func TestScript(t *testing.T) {
 62	flag.Parse()
 63
 64	mkkey := func(name string) (string, *keygen.SSHKeyPair) {
 65		path := filepath.Join(t.TempDir(), name)
 66		pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
 67		if err != nil {
 68			t.Fatal(err)
 69		}
 70		return path, pair
 71	}
 72
 73	admin1Key, admin1 := mkkey("admin1")
 74	_, admin2 := mkkey("admin2")
 75	user1Key, user1 := mkkey("user1")
 76
 77	testscript.Run(t, testscript.Params{
 78		Dir:                 "./testdata/",
 79		UpdateScripts:       *update,
 80		RequireExplicitExec: true,
 81		Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
 82			"soft":          cmdSoft("admin", admin1.Signer()),
 83			"usoft":         cmdSoft("user1", user1.Signer()),
 84			"git":           cmdGit(admin1Key),
 85			"ugit":          cmdGit(user1Key),
 86			"curl":          cmdCurl,
 87			"mkfile":        cmdMkfile,
 88			"envfile":       cmdEnvfile,
 89			"readfile":      cmdReadfile,
 90			"dos2unix":      cmdDos2Unix,
 91			"new-webhook":   cmdNewWebhook,
 92			"waitforserver": cmdWaitforserver,
 93			"stopserver":    cmdStopserver,
 94			"ui":            cmdUI(admin1.Signer()),
 95			"uui":           cmdUI(user1.Signer()),
 96		},
 97		Setup: func(e *testscript.Env) error {
 98			// Add binPath to PATH
 99			e.Setenv("PATH", fmt.Sprintf("%s%c%s", filepath.Dir(binPath), os.PathListSeparator, e.Getenv("PATH")))
100
101			data := t.TempDir()
102			sshPort := test.RandomPort()
103			sshListen := fmt.Sprintf("localhost:%d", sshPort)
104			gitPort := test.RandomPort()
105			gitListen := fmt.Sprintf("localhost:%d", gitPort)
106			httpPort := test.RandomPort()
107			httpListen := fmt.Sprintf("localhost:%d", httpPort)
108			statsPort := test.RandomPort()
109			statsListen := fmt.Sprintf("localhost:%d", statsPort)
110			serverName := "Test Soft Serve"
111
112			e.Setenv("DATA_PATH", data)
113			e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
114			e.Setenv("HTTP_PORT", fmt.Sprintf("%d", httpPort))
115			e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
116			e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
117			e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
118			e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
119			e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
120
121			// This is used to set up test specific configuration and http endpoints
122			e.Setenv("SOFT_SERVE_TESTRUN", "1")
123
124			// This will disable the default lipgloss renderer colors
125			e.Setenv("SOFT_SERVE_NO_COLOR", "1")
126
127			// Soft Serve debug environment variables
128			for _, env := range []string{
129				"SOFT_SERVE_DEBUG",
130				"SOFT_SERVE_VERBOSE",
131			} {
132				if v, ok := os.LookupEnv(env); ok {
133					e.Setenv(env, v)
134				}
135			}
136
137			// TODO: test different configs
138			cfg := config.DefaultConfig()
139			cfg.DataPath = data
140			cfg.Name = serverName
141			cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
142			cfg.SSH.ListenAddr = sshListen
143			cfg.SSH.PublicURL = "ssh://" + sshListen
144			cfg.Git.ListenAddr = gitListen
145			cfg.HTTP.ListenAddr = httpListen
146			cfg.HTTP.PublicURL = "http://" + httpListen
147			cfg.Stats.ListenAddr = statsListen
148			cfg.LFS.Enabled = true
149
150			// Parse os SOFT_SERVE environment variables
151			if err := cfg.ParseEnv(); err != nil {
152				return err
153			}
154
155			// Override the database data source if we're using postgres
156			// so we can create a temporary database for the tests.
157			if cfg.DB.Driver == "postgres" {
158				err, cleanup := setupPostgres(e.T(), cfg)
159				if err != nil {
160					return err
161				}
162				if cleanup != nil {
163					e.Defer(cleanup)
164				}
165			}
166
167			for _, env := range cfg.Environ() {
168				parts := strings.SplitN(env, "=", 2)
169				if len(parts) != 2 {
170					e.T().Fatal("invalid environment variable", env)
171				}
172				e.Setenv(parts[0], parts[1])
173			}
174
175			return nil
176		},
177	})
178}
179
180func cmdSoft(user string, key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
181	return func(ts *testscript.TestScript, neg bool, args []string) {
182		cli, err := ssh.Dial(
183			"tcp",
184			net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
185			&ssh.ClientConfig{
186				User:            user,
187				Auth:            []ssh.AuthMethod{ssh.PublicKeys(key)},
188				HostKeyCallback: ssh.InsecureIgnoreHostKey(),
189			},
190		)
191		ts.Check(err)
192		defer cli.Close()
193
194		sess, err := cli.NewSession()
195		ts.Check(err)
196		defer sess.Close()
197
198		sess.Stdout = ts.Stdout()
199		sess.Stderr = ts.Stderr()
200
201		check(ts, sess.Run(strings.Join(args, " ")), neg)
202	}
203}
204
205func cmdUI(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
206	return func(ts *testscript.TestScript, neg bool, args []string) {
207		if len(args) < 1 {
208			ts.Fatalf("usage: ui <quoted string input>")
209			return
210		}
211
212		cli, err := ssh.Dial(
213			"tcp",
214			net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
215			&ssh.ClientConfig{
216				User:            "git",
217				Auth:            []ssh.AuthMethod{ssh.PublicKeys(key)},
218				HostKeyCallback: ssh.InsecureIgnoreHostKey(),
219			},
220		)
221		check(ts, err, neg)
222		defer cli.Close()
223
224		sess, err := cli.NewSession()
225		check(ts, err, neg)
226		defer sess.Close()
227
228		// XXX: this is a hack to make the UI tests work
229		// cmp command always complains about an extra newline
230		// in the output
231		defer ts.Stdout().Write([]byte("\n"))
232
233		sess.Stdout = ts.Stdout()
234		sess.Stderr = ts.Stderr()
235
236		stdin, err := sess.StdinPipe()
237		check(ts, err, neg)
238
239		err = sess.RequestPty("dumb", 40, 80, ssh.TerminalModes{})
240		check(ts, err, neg)
241		check(ts, sess.Start(""), neg)
242
243		in, err := strconv.Unquote(args[0])
244		check(ts, err, neg)
245		reader := strings.NewReader(in)
246		go func() {
247			defer stdin.Close()
248			for {
249				r, _, err := reader.ReadRune()
250				if err == io.EOF {
251					break
252				}
253				check(ts, err, neg)
254				stdin.Write([]byte(string(r))) // nolint: errcheck
255
256				// Wait for the UI to process the input
257				time.Sleep(100 * time.Millisecond)
258			}
259		}()
260
261		check(ts, sess.Wait(), neg)
262	}
263}
264
265// P.S. Windows sucks!
266func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
267	if neg {
268		ts.Fatalf("unsupported: ! dos2unix")
269	}
270	if len(args) < 1 {
271		ts.Fatalf("usage: dos2unix paths...")
272	}
273	for _, arg := range args {
274		filename := ts.MkAbs(arg)
275		data, err := os.ReadFile(filename)
276		if err != nil {
277			ts.Fatalf("%s: %v", filename, err)
278		}
279
280		// Replace all '\r\n' with '\n'.
281		data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
282
283		if err := os.WriteFile(filename, data, 0o644); err != nil {
284			ts.Fatalf("%s: %v", filename, err)
285		}
286	}
287}
288
// sshConfig is the OpenSSH client configuration template that cmdGit writes
// to $SSH_KNOWN_CONFIG_FILE. Its single %q verb receives the known_hosts path.
// The options disable strict host key checking and the SSH agent, and limit
// authentication to the explicitly supplied identity.
var sshConfig = "\n" +
	"Host *\n" +
	"  UserKnownHostsFile %q\n" +
	"  StrictHostKeyChecking no\n" +
	"  IdentityAgent none\n" +
	"  IdentitiesOnly yes\n" +
	"  ServerAliveInterval 60\n"
297
298func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
299	return func(ts *testscript.TestScript, neg bool, args []string) {
300		ts.Check(os.WriteFile(
301			ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
302			[]byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
303			0o600,
304		))
305		sshArgs := []string{
306			"-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
307			"-i", filepath.ToSlash(key),
308		}
309		ts.Setenv(
310			"GIT_SSH_COMMAND",
311			strings.Join(append([]string{"ssh"}, sshArgs...), " "),
312		)
313		// Disable git prompting for credentials.
314		ts.Setenv("GIT_TERMINAL_PROMPT", "0")
315		args = append([]string{
316			"-c", "user.email=john@example.com",
317			"-c", "user.name=John Doe",
318		}, args...)
319		check(ts, ts.Exec("git", args...), neg)
320	}
321}
322
323func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
324	if len(args) < 2 {
325		ts.Fatalf("usage: mkfile path content")
326	}
327	check(ts, os.WriteFile(
328		ts.MkAbs(args[0]),
329		[]byte(strings.Join(args[1:], " ")),
330		0o644,
331	), neg)
332}
333
334func check(ts *testscript.TestScript, err error, neg bool) {
335	if neg && err == nil {
336		ts.Fatalf("expected error, got nil")
337	}
338	if !neg {
339		ts.Check(err)
340	}
341}
342
343func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) {
344	ts.Stdout().Write([]byte(ts.ReadFile(args[0])))
345}
346
347func cmdEnvfile(ts *testscript.TestScript, neg bool, args []string) {
348	if len(args) < 1 {
349		ts.Fatalf("usage: envfile key=file...")
350	}
351
352	for _, arg := range args {
353		parts := strings.SplitN(arg, "=", 2)
354		if len(parts) != 2 {
355			ts.Fatalf("usage: envfile key=file...")
356		}
357		key := parts[0]
358		file := parts[1]
359		ts.Setenv(key, strings.TrimSpace(ts.ReadFile(file)))
360	}
361}
362
363func cmdNewWebhook(ts *testscript.TestScript, neg bool, args []string) {
364	type webhookSite struct {
365		UUID string `json:"uuid"`
366	}
367
368	if len(args) != 1 {
369		ts.Fatalf("usage: new-webhook <env-name>")
370	}
371
372	const whSite = "https://webhook.site"
373	req, err := http.NewRequest(http.MethodPost, whSite+"/token", nil)
374	check(ts, err, neg)
375
376	resp, err := http.DefaultClient.Do(req)
377	check(ts, err, neg)
378
379	defer resp.Body.Close()
380	var site webhookSite
381	check(ts, json.NewDecoder(resp.Body).Decode(&site), neg)
382
383	ts.Setenv(args[0], whSite+"/"+site.UUID)
384}
385
386func cmdCurl(ts *testscript.TestScript, neg bool, args []string) {
387	var verbose bool
388	var headers []string
389	var data string
390	method := http.MethodGet
391
392	cmd := &cobra.Command{
393		Use:  "curl",
394		Args: cobra.MinimumNArgs(1),
395		RunE: func(cmd *cobra.Command, args []string) error {
396			url, err := url.Parse(args[0])
397			if err != nil {
398				return err
399			}
400
401			req, err := http.NewRequest(method, url.String(), nil)
402			if err != nil {
403				return err
404			}
405
406			if data != "" {
407				req.Body = io.NopCloser(strings.NewReader(data))
408			}
409
410			if verbose {
411				fmt.Fprintf(cmd.ErrOrStderr(), "< %s %s\n", req.Method, url.String())
412			}
413
414			for _, header := range headers {
415				parts := strings.SplitN(header, ":", 2)
416				if len(parts) != 2 {
417					return fmt.Errorf("invalid header: %s", header)
418				}
419				req.Header.Add(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
420			}
421
422			if userInfo := url.User; userInfo != nil {
423				password, _ := userInfo.Password()
424				req.SetBasicAuth(userInfo.Username(), password)
425			}
426
427			if verbose {
428				for key, values := range req.Header {
429					for _, value := range values {
430						fmt.Fprintf(cmd.ErrOrStderr(), "< %s: %s\n", key, value)
431					}
432				}
433			}
434
435			resp, err := http.DefaultClient.Do(req)
436			if err != nil {
437				return err
438			}
439
440			if verbose {
441				fmt.Fprintf(ts.Stderr(), "> %s\n", resp.Status)
442				for key, values := range resp.Header {
443					for _, value := range values {
444						fmt.Fprintf(cmd.ErrOrStderr(), "> %s: %s\n", key, value)
445					}
446				}
447			}
448
449			defer resp.Body.Close()
450			buf, err := io.ReadAll(resp.Body)
451			if err != nil {
452				return err
453			}
454
455			cmd.Print(string(buf))
456
457			return nil
458		},
459	}
460
461	cmd.SetArgs(args)
462	cmd.SetOut(ts.Stdout())
463	cmd.SetErr(ts.Stderr())
464
465	cmd.Flags().BoolVarP(&verbose, "verbose", "v", verbose, "verbose")
466	cmd.Flags().StringArrayVarP(&headers, "header", "H", nil, "HTTP header")
467	cmd.Flags().StringVarP(&method, "request", "X", method, "HTTP method")
468	cmd.Flags().StringVarP(&data, "data", "d", data, "HTTP data")
469
470	check(ts, cmd.Execute(), neg)
471}
472
473func cmdWaitforserver(ts *testscript.TestScript, neg bool, args []string) {
474	// wait until the server is up
475	addr := net.JoinHostPort("localhost", ts.Getenv("SSH_PORT"))
476	for {
477		conn, _ := net.DialTimeout(
478			"tcp",
479			addr,
480			time.Second,
481		)
482		if conn != nil {
483			conn.Close()
484			break
485		}
486	}
487}
488
489func cmdStopserver(ts *testscript.TestScript, neg bool, args []string) {
490	// stop the server
491	resp, err := http.DefaultClient.Head(fmt.Sprintf("%s/__stop", ts.Getenv("SOFT_SERVE_HTTP_PUBLIC_URL")))
492	check(ts, err, neg)
493	resp.Body.Close()
494	time.Sleep(time.Second * 2) // Allow some time for the server to stop
495}
496
497func setupPostgres(t testscript.T, cfg *config.Config) (error, func()) {
498	// Indicates postgres
499	// Create a disposable database
500	rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
501	dbName := fmt.Sprintf("softserve_test_%d", rnd.Int63())
502	dbDsn := cfg.DB.DataSource
503	if dbDsn == "" {
504		cfg.DB.DataSource = "postgres://postgres@localhost:5432/postgres?sslmode=disable"
505	}
506
507	dbUrl, err := url.Parse(cfg.DB.DataSource)
508	if err != nil {
509		return err, nil
510	}
511
512	scheme := dbUrl.Scheme
513	if scheme == "" {
514		scheme = "postgres"
515	}
516
517	host := dbUrl.Hostname()
518	if host == "" {
519		host = "localhost"
520	}
521
522	connInfo := fmt.Sprintf("host=%s sslmode=disable", host)
523	username := dbUrl.User.Username()
524	if username != "" {
525		connInfo += fmt.Sprintf(" user=%s", username)
526		password, ok := dbUrl.User.Password()
527		if ok {
528			username = fmt.Sprintf("%s:%s", username, password)
529			connInfo += fmt.Sprintf(" password=%s", password)
530		}
531		username = fmt.Sprintf("%s@", username)
532	} else {
533		connInfo += " user=postgres"
534		username = "postgres@"
535	}
536
537	port := dbUrl.Port()
538	if port != "" {
539		connInfo += fmt.Sprintf(" port=%s", port)
540		port = fmt.Sprintf(":%s", port)
541	}
542
543	cfg.DB.DataSource = fmt.Sprintf("%s://%s%s%s/%s?sslmode=disable",
544		scheme,
545		username,
546		host,
547		port,
548		dbName,
549	)
550
551	// Create the database
552	dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
553	if err != nil {
554		return err, nil
555	}
556
557	if _, err := dbx.Exec("CREATE DATABASE " + dbName); err != nil {
558		return err, nil
559	}
560
561	return nil, func() {
562		dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
563		if err != nil {
564			t.Fatal("failed to open database", dbName, err)
565		}
566
567		if _, err := dbx.Exec("DROP DATABASE " + dbName); err != nil {
568			t.Fatal("failed to drop database", dbName, err)
569		}
570	}
571}