1package testscript
  2
  3import (
  4	"bytes"
  5	"context"
  6	"encoding/json"
  7	"flag"
  8	"fmt"
  9	"io"
 10	"math/rand"
 11	"net"
 12	"net/http"
 13	"net/url"
 14	"os"
 15	"os/exec"
 16	"path/filepath"
 17	"runtime"
 18	"strconv"
 19	"strings"
 20	"testing"
 21	"time"
 22
 23	"github.com/charmbracelet/keygen"
 24	"github.com/charmbracelet/soft-serve/pkg/config"
 25	"github.com/charmbracelet/soft-serve/pkg/db"
 26	"github.com/charmbracelet/soft-serve/pkg/test"
 27	"github.com/rogpeppe/go-internal/testscript"
 28	"github.com/spf13/cobra"
 29	"golang.org/x/crypto/ssh"
 30)
 31
 32var (
 33	update  = flag.Bool("update", false, "update script files")
 34	binPath string
 35)
 36
 37func TestMain(m *testing.M) {
 38	tmp, err := os.MkdirTemp("", "soft-serve*")
 39	if err != nil {
 40		fmt.Fprintf(os.Stderr, "failed to create temporary directory: %s", err)
 41		os.Exit(1)
 42	}
 43	defer os.RemoveAll(tmp)
 44
 45	binPath = filepath.Join(tmp, "soft")
 46	if runtime.GOOS == "windows" {
 47		binPath += ".exe"
 48	}
 49
 50	// Build the soft binary with -cover flag.
 51	cmd := exec.Command("go", "build", "-race", "-cover", "-o", binPath, filepath.Join("..", "cmd", "soft"))
 52	if err := cmd.Run(); err != nil {
 53		fmt.Fprintf(os.Stderr, "failed to build soft-serve binary: %s", err)
 54		os.Exit(1)
 55	}
 56
 57	// Run tests
 58	os.Exit(m.Run())
 59}
 60
 61func TestScript(t *testing.T) {
 62	flag.Parse()
 63
 64	mkkey := func(name string) (string, *keygen.SSHKeyPair) {
 65		path := filepath.Join(t.TempDir(), name)
 66		pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
 67		if err != nil {
 68			t.Fatal(err)
 69		}
 70		return path, pair
 71	}
 72
 73	key, admin1 := mkkey("admin1")
 74	_, admin2 := mkkey("admin2")
 75	_, user1 := mkkey("user1")
 76
 77	testscript.Run(t, testscript.Params{
 78		Dir:                 "./testdata/",
 79		UpdateScripts:       *update,
 80		RequireExplicitExec: true,
 81		Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
 82			"soft":          cmdSoft("admin", admin1.Signer()),
 83			"usoft":         cmdSoft("user1", user1.Signer()),
 84			"git":           cmdGit(key),
 85			"curl":          cmdCurl,
 86			"mkfile":        cmdMkfile,
 87			"envfile":       cmdEnvfile,
 88			"readfile":      cmdReadfile,
 89			"dos2unix":      cmdDos2Unix,
 90			"new-webhook":   cmdNewWebhook,
 91			"waitforserver": cmdWaitforserver,
 92			"stopserver":    cmdStopserver,
 93			"ui":            cmdUI(admin1.Signer()),
 94			"uui":           cmdUI(user1.Signer()),
 95		},
 96		Setup: func(e *testscript.Env) error {
 97			// Add binPath to PATH
 98			e.Setenv("PATH", fmt.Sprintf("%s%c%s", filepath.Dir(binPath), os.PathListSeparator, e.Getenv("PATH")))
 99
100			data := t.TempDir()
101			sshPort := test.RandomPort()
102			sshListen := fmt.Sprintf("localhost:%d", sshPort)
103			gitPort := test.RandomPort()
104			gitListen := fmt.Sprintf("localhost:%d", gitPort)
105			httpPort := test.RandomPort()
106			httpListen := fmt.Sprintf("localhost:%d", httpPort)
107			statsPort := test.RandomPort()
108			statsListen := fmt.Sprintf("localhost:%d", statsPort)
109			serverName := "Test Soft Serve"
110
111			e.Setenv("DATA_PATH", data)
112			e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
113			e.Setenv("HTTP_PORT", fmt.Sprintf("%d", httpPort))
114			e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
115			e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
116			e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
117			e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
118			e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
119
120			// This is used to set up test specific configuration and http endpoints
121			e.Setenv("SOFT_SERVE_TESTRUN", "1")
122
123			// This will disable the default lipgloss renderer colors
124			e.Setenv("SOFT_SERVE_NO_COLOR", "1")
125
126			// Soft Serve debug environment variables
127			for _, env := range []string{
128				"SOFT_SERVE_DEBUG",
129				"SOFT_SERVE_VERBOSE",
130			} {
131				if v, ok := os.LookupEnv(env); ok {
132					e.Setenv(env, v)
133				}
134			}
135
136			// TODO: test different configs
137			cfg := config.DefaultConfig()
138			cfg.DataPath = data
139			cfg.Name = serverName
140			cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
141			cfg.SSH.ListenAddr = sshListen
142			cfg.SSH.PublicURL = "ssh://" + sshListen
143			cfg.Git.ListenAddr = gitListen
144			cfg.HTTP.ListenAddr = httpListen
145			cfg.HTTP.PublicURL = "http://" + httpListen
146			cfg.Stats.ListenAddr = statsListen
147			cfg.LFS.Enabled = true
148
149			// Parse os SOFT_SERVE environment variables
150			if err := cfg.ParseEnv(); err != nil {
151				return err
152			}
153
154			// Override the database data source if we're using postgres
155			// so we can create a temporary database for the tests.
156			if cfg.DB.Driver == "postgres" {
157				err, cleanup := setupPostgres(e.T(), cfg)
158				if err != nil {
159					return err
160				}
161				if cleanup != nil {
162					e.Defer(cleanup)
163				}
164			}
165
166			for _, env := range cfg.Environ() {
167				parts := strings.SplitN(env, "=", 2)
168				if len(parts) != 2 {
169					e.T().Fatal("invalid environment variable", env)
170				}
171				e.Setenv(parts[0], parts[1])
172			}
173
174			return nil
175		},
176	})
177}
178
179func cmdSoft(user string, key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
180	return func(ts *testscript.TestScript, neg bool, args []string) {
181		cli, err := ssh.Dial(
182			"tcp",
183			net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
184			&ssh.ClientConfig{
185				User:            user,
186				Auth:            []ssh.AuthMethod{ssh.PublicKeys(key)},
187				HostKeyCallback: ssh.InsecureIgnoreHostKey(),
188			},
189		)
190		ts.Check(err)
191		defer cli.Close()
192
193		sess, err := cli.NewSession()
194		ts.Check(err)
195		defer sess.Close()
196
197		sess.Stdout = ts.Stdout()
198		sess.Stderr = ts.Stderr()
199
200		check(ts, sess.Run(strings.Join(args, " ")), neg)
201	}
202}
203
204func cmdUI(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
205	return func(ts *testscript.TestScript, neg bool, args []string) {
206		if len(args) < 1 {
207			ts.Fatalf("usage: ui <quoted string input>")
208			return
209		}
210
211		cli, err := ssh.Dial(
212			"tcp",
213			net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
214			&ssh.ClientConfig{
215				User:            "git",
216				Auth:            []ssh.AuthMethod{ssh.PublicKeys(key)},
217				HostKeyCallback: ssh.InsecureIgnoreHostKey(),
218			},
219		)
220		check(ts, err, neg)
221		defer cli.Close()
222
223		sess, err := cli.NewSession()
224		check(ts, err, neg)
225		defer sess.Close()
226
227		// XXX: this is a hack to make the UI tests work
228		// cmp command always complains about an extra newline
229		// in the output
230		defer ts.Stdout().Write([]byte("\n"))
231
232		sess.Stdout = ts.Stdout()
233		sess.Stderr = ts.Stderr()
234
235		stdin, err := sess.StdinPipe()
236		check(ts, err, neg)
237
238		err = sess.RequestPty("dumb", 40, 80, ssh.TerminalModes{})
239		check(ts, err, neg)
240		check(ts, sess.Start(""), neg)
241
242		in, err := strconv.Unquote(args[0])
243		check(ts, err, neg)
244		reader := strings.NewReader(in)
245		go func() {
246			defer stdin.Close()
247			for {
248				r, _, err := reader.ReadRune()
249				if err == io.EOF {
250					break
251				}
252				check(ts, err, neg)
253				stdin.Write([]byte(string(r))) // nolint: errcheck
254
255				// Wait for the UI to process the input
256				time.Sleep(100 * time.Millisecond)
257			}
258		}()
259
260		check(ts, sess.Wait(), neg)
261	}
262}
263
264// P.S. Windows sucks!
265func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
266	if neg {
267		ts.Fatalf("unsupported: ! dos2unix")
268	}
269	if len(args) < 1 {
270		ts.Fatalf("usage: dos2unix paths...")
271	}
272	for _, arg := range args {
273		filename := ts.MkAbs(arg)
274		data, err := os.ReadFile(filename)
275		if err != nil {
276			ts.Fatalf("%s: %v", filename, err)
277		}
278
279		// Replace all '\r\n' with '\n'.
280		data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
281
282		if err := os.WriteFile(filename, data, 0o644); err != nil {
283			ts.Fatalf("%s: %v", filename, err)
284		}
285	}
286}
287
// sshConfig is the ssh_config template that the git command writes to
// $SSH_KNOWN_CONFIG_FILE before each invocation. Its single %q verb receives
// the per-script known_hosts path. Host-key prompting is disabled, and the
// agent is turned off with IdentitiesOnly set, so ssh uses only the
// identity passed on its command line.
var sshConfig = "\n" +
	"Host *\n" +
	"  UserKnownHostsFile %q\n" +
	"  StrictHostKeyChecking no\n" +
	"  IdentityAgent none\n" +
	"  IdentitiesOnly yes\n" +
	"  ServerAliveInterval 60\n"
296
297func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
298	return func(ts *testscript.TestScript, neg bool, args []string) {
299		ts.Check(os.WriteFile(
300			ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
301			[]byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
302			0o600,
303		))
304		sshArgs := []string{
305			"-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
306			"-i", filepath.ToSlash(key),
307		}
308		ts.Setenv(
309			"GIT_SSH_COMMAND",
310			strings.Join(append([]string{"ssh"}, sshArgs...), " "),
311		)
312		// Disable git prompting for credentials.
313		ts.Setenv("GIT_TERMINAL_PROMPT", "0")
314		args = append([]string{
315			"-c", "user.email=john@example.com",
316			"-c", "user.name=John Doe",
317		}, args...)
318		check(ts, ts.Exec("git", args...), neg)
319	}
320}
321
322func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
323	if len(args) < 2 {
324		ts.Fatalf("usage: mkfile path content")
325	}
326	check(ts, os.WriteFile(
327		ts.MkAbs(args[0]),
328		[]byte(strings.Join(args[1:], " ")),
329		0o644,
330	), neg)
331}
332
333func check(ts *testscript.TestScript, err error, neg bool) {
334	if neg && err == nil {
335		ts.Fatalf("expected error, got nil")
336	}
337	if !neg {
338		ts.Check(err)
339	}
340}
341
342func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) {
343	ts.Stdout().Write([]byte(ts.ReadFile(args[0])))
344}
345
346func cmdEnvfile(ts *testscript.TestScript, neg bool, args []string) {
347	if len(args) < 1 {
348		ts.Fatalf("usage: envfile key=file...")
349	}
350
351	for _, arg := range args {
352		parts := strings.SplitN(arg, "=", 2)
353		if len(parts) != 2 {
354			ts.Fatalf("usage: envfile key=file...")
355		}
356		key := parts[0]
357		file := parts[1]
358		ts.Setenv(key, strings.TrimSpace(ts.ReadFile(file)))
359	}
360}
361
362func cmdNewWebhook(ts *testscript.TestScript, neg bool, args []string) {
363	type webhookSite struct {
364		UUID string `json:"uuid"`
365	}
366
367	if len(args) != 1 {
368		ts.Fatalf("usage: new-webhook <env-name>")
369	}
370
371	const whSite = "https://webhook.site"
372	req, err := http.NewRequest(http.MethodPost, whSite+"/token", nil)
373	check(ts, err, neg)
374
375	resp, err := http.DefaultClient.Do(req)
376	check(ts, err, neg)
377
378	defer resp.Body.Close()
379	var site webhookSite
380	check(ts, json.NewDecoder(resp.Body).Decode(&site), neg)
381
382	ts.Setenv(args[0], whSite+"/"+site.UUID)
383}
384
385func cmdCurl(ts *testscript.TestScript, neg bool, args []string) {
386	var verbose bool
387	var headers []string
388	var data string
389	method := http.MethodGet
390
391	cmd := &cobra.Command{
392		Use:  "curl",
393		Args: cobra.MinimumNArgs(1),
394		RunE: func(cmd *cobra.Command, args []string) error {
395			url, err := url.Parse(args[0])
396			if err != nil {
397				return err
398			}
399
400			req, err := http.NewRequest(method, url.String(), nil)
401			if err != nil {
402				return err
403			}
404
405			if data != "" {
406				req.Body = io.NopCloser(strings.NewReader(data))
407			}
408
409			if verbose {
410				fmt.Fprintf(cmd.ErrOrStderr(), "< %s %s\n", req.Method, url.String())
411			}
412
413			for _, header := range headers {
414				parts := strings.SplitN(header, ":", 2)
415				if len(parts) != 2 {
416					return fmt.Errorf("invalid header: %s", header)
417				}
418				req.Header.Add(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
419			}
420
421			if userInfo := url.User; userInfo != nil {
422				password, _ := userInfo.Password()
423				req.SetBasicAuth(userInfo.Username(), password)
424			}
425
426			if verbose {
427				for key, values := range req.Header {
428					for _, value := range values {
429						fmt.Fprintf(cmd.ErrOrStderr(), "< %s: %s\n", key, value)
430					}
431				}
432			}
433
434			resp, err := http.DefaultClient.Do(req)
435			if err != nil {
436				return err
437			}
438
439			if verbose {
440				fmt.Fprintf(ts.Stderr(), "> %s\n", resp.Status)
441				for key, values := range resp.Header {
442					for _, value := range values {
443						fmt.Fprintf(cmd.ErrOrStderr(), "> %s: %s\n", key, value)
444					}
445				}
446			}
447
448			defer resp.Body.Close()
449			buf, err := io.ReadAll(resp.Body)
450			if err != nil {
451				return err
452			}
453
454			cmd.Print(string(buf))
455
456			return nil
457		},
458	}
459
460	cmd.SetArgs(args)
461	cmd.SetOut(ts.Stdout())
462	cmd.SetErr(ts.Stderr())
463
464	cmd.Flags().BoolVarP(&verbose, "verbose", "v", verbose, "verbose")
465	cmd.Flags().StringArrayVarP(&headers, "header", "H", nil, "HTTP header")
466	cmd.Flags().StringVarP(&method, "request", "X", method, "HTTP method")
467	cmd.Flags().StringVarP(&data, "data", "d", data, "HTTP data")
468
469	check(ts, cmd.Execute(), neg)
470}
471
472func cmdWaitforserver(ts *testscript.TestScript, neg bool, args []string) {
473	// wait until the server is up
474	addr := net.JoinHostPort("localhost", ts.Getenv("SSH_PORT"))
475	for {
476		conn, _ := net.DialTimeout(
477			"tcp",
478			addr,
479			time.Second,
480		)
481		if conn != nil {
482			conn.Close()
483			break
484		}
485	}
486}
487
488func cmdStopserver(ts *testscript.TestScript, neg bool, args []string) {
489	// stop the server
490	resp, err := http.DefaultClient.Head(fmt.Sprintf("%s/__stop", ts.Getenv("SOFT_SERVE_HTTP_PUBLIC_URL")))
491	check(ts, err, neg)
492	resp.Body.Close()
493	time.Sleep(time.Second * 2) // Allow some time for the server to stop
494}
495
496func setupPostgres(t testscript.T, cfg *config.Config) (error, func()) {
497	// Indicates postgres
498	// Create a disposable database
499	rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
500	dbName := fmt.Sprintf("softserve_test_%d", rnd.Int63())
501	dbDsn := cfg.DB.DataSource
502	if dbDsn == "" {
503		cfg.DB.DataSource = "postgres://postgres@localhost:5432/postgres?sslmode=disable"
504	}
505
506	dbUrl, err := url.Parse(cfg.DB.DataSource)
507	if err != nil {
508		return err, nil
509	}
510
511	scheme := dbUrl.Scheme
512	if scheme == "" {
513		scheme = "postgres"
514	}
515
516	host := dbUrl.Hostname()
517	if host == "" {
518		host = "localhost"
519	}
520
521	connInfo := fmt.Sprintf("host=%s sslmode=disable", host)
522	username := dbUrl.User.Username()
523	if username != "" {
524		connInfo += fmt.Sprintf(" user=%s", username)
525		password, ok := dbUrl.User.Password()
526		if ok {
527			username = fmt.Sprintf("%s:%s", username, password)
528			connInfo += fmt.Sprintf(" password=%s", password)
529		}
530		username = fmt.Sprintf("%s@", username)
531	} else {
532		connInfo += " user=postgres"
533		username = "postgres@"
534	}
535
536	port := dbUrl.Port()
537	if port != "" {
538		connInfo += fmt.Sprintf(" port=%s", port)
539		port = fmt.Sprintf(":%s", port)
540	}
541
542	cfg.DB.DataSource = fmt.Sprintf("%s://%s%s%s/%s?sslmode=disable",
543		scheme,
544		username,
545		host,
546		port,
547		dbName,
548	)
549
550	// Create the database
551	dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
552	if err != nil {
553		return err, nil
554	}
555
556	if _, err := dbx.Exec("CREATE DATABASE " + dbName); err != nil {
557		return err, nil
558	}
559
560	return nil, func() {
561		dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
562		if err != nil {
563			t.Fatal("failed to open database", dbName, err)
564		}
565
566		if _, err := dbx.Exec("DROP DATABASE " + dbName); err != nil {
567			t.Fatal("failed to drop database", dbName, err)
568		}
569	}
570}