script_test.go

  1package testscript
  2
  3import (
  4	"bytes"
  5	"context"
  6	"encoding/json"
  7	"flag"
  8	"fmt"
  9	"io"
 10	"math/rand"
 11	"net"
 12	"net/http"
 13	"net/url"
 14	"os"
 15	"os/exec"
 16	"path/filepath"
 17	"runtime"
 18	"strings"
 19	"testing"
 20	"time"
 21
 22	"github.com/charmbracelet/keygen"
 23	"github.com/charmbracelet/soft-serve/pkg/config"
 24	"github.com/charmbracelet/soft-serve/pkg/db"
 25	"github.com/charmbracelet/soft-serve/pkg/test"
 26	"github.com/rogpeppe/go-internal/testscript"
 27	"github.com/spf13/cobra"
 28	"golang.org/x/crypto/ssh"
 29)
 30
 31var (
 32	update  = flag.Bool("update", false, "update script files")
 33	binPath string
 34)
 35
 36func TestMain(m *testing.M) {
 37	tmp, err := os.MkdirTemp("", "soft-serve*")
 38	if err != nil {
 39		fmt.Fprintf(os.Stderr, "failed to create temporary directory: %s", err)
 40		os.Exit(1)
 41	}
 42	defer os.RemoveAll(tmp)
 43
 44	binPath = filepath.Join(tmp, "soft")
 45	if runtime.GOOS == "windows" {
 46		binPath += ".exe"
 47	}
 48
 49	// Build the soft binary with -cover flag.
 50	cmd := exec.Command("go", "build", "-race", "-cover", "-o", binPath, filepath.Join("..", "cmd", "soft"))
 51	if err := cmd.Run(); err != nil {
 52		fmt.Fprintf(os.Stderr, "failed to build soft-serve binary: %s", err)
 53		os.Exit(1)
 54	}
 55
 56	// Run tests
 57	os.Exit(m.Run())
 58
 59	// Add binPath to PATH
 60	os.Setenv("PATH", fmt.Sprintf("%s%c%s", os.Getenv("PATH"), os.PathListSeparator, filepath.Dir(binPath)))
 61}
 62
 63func TestScript(t *testing.T) {
 64	flag.Parse()
 65
 66	mkkey := func(name string) (string, *keygen.SSHKeyPair) {
 67		path := filepath.Join(t.TempDir(), name)
 68		pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
 69		if err != nil {
 70			t.Fatal(err)
 71		}
 72		return path, pair
 73	}
 74
 75	key, admin1 := mkkey("admin1")
 76	_, admin2 := mkkey("admin2")
 77	_, user1 := mkkey("user1")
 78
 79	testscript.Run(t, testscript.Params{
 80		Dir:                 "./testdata/",
 81		UpdateScripts:       *update,
 82		RequireExplicitExec: true,
 83		Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
 84			"soft":          cmdSoft(admin1.Signer()),
 85			"usoft":         cmdSoft(user1.Signer()),
 86			"git":           cmdGit(key),
 87			"curl":          cmdCurl,
 88			"mkfile":        cmdMkfile,
 89			"envfile":       cmdEnvfile,
 90			"readfile":      cmdReadfile,
 91			"dos2unix":      cmdDos2Unix,
 92			"new-webhook":   cmdNewWebhook,
 93			"waitforserver": cmdWaitforserver,
 94			"stopserver":    cmdStopserver,
 95		},
 96		Setup: func(e *testscript.Env) error {
 97			// Add binPath to PATH
 98			e.Setenv("PATH", fmt.Sprintf("%s%c%s", filepath.Dir(binPath), os.PathListSeparator, e.Getenv("PATH")))
 99
100			data := t.TempDir()
101			sshPort := test.RandomPort()
102			sshListen := fmt.Sprintf("localhost:%d", sshPort)
103			gitPort := test.RandomPort()
104			gitListen := fmt.Sprintf("localhost:%d", gitPort)
105			httpPort := test.RandomPort()
106			httpListen := fmt.Sprintf("localhost:%d", httpPort)
107			statsPort := test.RandomPort()
108			statsListen := fmt.Sprintf("localhost:%d", statsPort)
109			serverName := "Test Soft Serve"
110
111			e.Setenv("DATA_PATH", data)
112			e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
113			e.Setenv("HTTP_PORT", fmt.Sprintf("%d", httpPort))
114			e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
115			e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
116			e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
117			e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
118			e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
119
120			// This is used to set up test specific configuration and http endpoints
121			e.Setenv("SOFT_SERVE_TESTRUN", "1")
122
123			// Soft Serve debug environment variables
124			for _, env := range []string{
125				"SOFT_SERVE_DEBUG",
126				"SOFT_SERVE_VERBOSE",
127			} {
128				if v, ok := os.LookupEnv(env); ok {
129					e.Setenv(env, v)
130				}
131			}
132
133			// TODO: test different configs
134			cfg := config.DefaultConfig()
135			cfg.DataPath = data
136			cfg.Name = serverName
137			cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
138			cfg.SSH.ListenAddr = sshListen
139			cfg.SSH.PublicURL = "ssh://" + sshListen
140			cfg.Git.ListenAddr = gitListen
141			cfg.HTTP.ListenAddr = httpListen
142			cfg.HTTP.PublicURL = "http://" + httpListen
143			cfg.Stats.ListenAddr = statsListen
144			cfg.LFS.Enabled = true
145			// cfg.LFS.SSHEnabled = true
146
147			// Parse os SOFT_SERVE environment variables
148			if err := cfg.ParseEnv(); err != nil {
149				return err
150			}
151
152			// Override the database data source if we're using postgres
153			// so we can create a temporary database for the tests.
154			if cfg.DB.Driver == "postgres" {
155				err, cleanup := setupPostgres(e.T(), cfg)
156				if err != nil {
157					return err
158				}
159				if cleanup != nil {
160					e.Defer(cleanup)
161				}
162			}
163
164			for _, env := range cfg.Environ() {
165				parts := strings.SplitN(env, "=", 2)
166				if len(parts) != 2 {
167					e.T().Fatal("invalid environment variable", env)
168				}
169				e.Setenv(parts[0], parts[1])
170			}
171
172			return nil
173		},
174	})
175}
176
177func cmdSoft(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
178	return func(ts *testscript.TestScript, neg bool, args []string) {
179		cli, err := ssh.Dial(
180			"tcp",
181			net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
182			&ssh.ClientConfig{
183				User:            "admin",
184				Auth:            []ssh.AuthMethod{ssh.PublicKeys(key)},
185				HostKeyCallback: ssh.InsecureIgnoreHostKey(),
186			},
187		)
188		ts.Check(err)
189		defer cli.Close()
190
191		sess, err := cli.NewSession()
192		ts.Check(err)
193		defer sess.Close()
194
195		sess.Stdout = ts.Stdout()
196		sess.Stderr = ts.Stderr()
197
198		check(ts, sess.Run(strings.Join(args, " ")), neg)
199	}
200}
201
202// P.S. Windows sucks!
203func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
204	if neg {
205		ts.Fatalf("unsupported: ! dos2unix")
206	}
207	if len(args) < 1 {
208		ts.Fatalf("usage: dos2unix paths...")
209	}
210	for _, arg := range args {
211		filename := ts.MkAbs(arg)
212		data, err := os.ReadFile(filename)
213		if err != nil {
214			ts.Fatalf("%s: %v", filename, err)
215		}
216
217		// Replace all '\r\n' with '\n'.
218		data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
219
220		if err := os.WriteFile(filename, data, 0o644); err != nil {
221			ts.Fatalf("%s: %v", filename, err)
222		}
223	}
224}
225
// sshConfig is an OpenSSH client config template applied to every host. Its
// single %q verb receives the known_hosts file path. It disables strict host
// key checking and the SSH agent, and restricts authentication to
// explicitly configured identities.
var sshConfig = "\n" +
	"Host *\n" +
	"  UserKnownHostsFile %q\n" +
	"  StrictHostKeyChecking no\n" +
	"  IdentityAgent none\n" +
	"  IdentitiesOnly yes\n" +
	"  ServerAliveInterval 60\n"
234
235func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
236	return func(ts *testscript.TestScript, neg bool, args []string) {
237		ts.Check(os.WriteFile(
238			ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
239			[]byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
240			0o600,
241		))
242		sshArgs := []string{
243			"-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
244			"-i", filepath.ToSlash(key),
245		}
246		ts.Setenv(
247			"GIT_SSH_COMMAND",
248			strings.Join(append([]string{"ssh"}, sshArgs...), " "),
249		)
250		// Disable git prompting for credentials.
251		ts.Setenv("GIT_TERMINAL_PROMPT", "0")
252		args = append([]string{
253			"-c", "user.email=john@example.com",
254			"-c", "user.name=John Doe",
255		}, args...)
256		check(ts, ts.Exec("git", args...), neg)
257	}
258}
259
260func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
261	if len(args) < 2 {
262		ts.Fatalf("usage: mkfile path content")
263	}
264	check(ts, os.WriteFile(
265		ts.MkAbs(args[0]),
266		[]byte(strings.Join(args[1:], " ")),
267		0o644,
268	), neg)
269}
270
271func check(ts *testscript.TestScript, err error, neg bool) {
272	if neg && err == nil {
273		ts.Fatalf("expected error, got nil")
274	}
275	if !neg {
276		ts.Check(err)
277	}
278}
279
280func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) {
281	ts.Stdout().Write([]byte(ts.ReadFile(args[0])))
282}
283
284func cmdEnvfile(ts *testscript.TestScript, neg bool, args []string) {
285	if len(args) < 1 {
286		ts.Fatalf("usage: envfile key=file...")
287	}
288
289	for _, arg := range args {
290		parts := strings.SplitN(arg, "=", 2)
291		if len(parts) != 2 {
292			ts.Fatalf("usage: envfile key=file...")
293		}
294		key := parts[0]
295		file := parts[1]
296		ts.Setenv(key, strings.TrimSpace(ts.ReadFile(file)))
297	}
298}
299
300func cmdNewWebhook(ts *testscript.TestScript, neg bool, args []string) {
301	type webhookSite struct {
302		UUID string `json:"uuid"`
303	}
304
305	if len(args) != 1 {
306		ts.Fatalf("usage: new-webhook <env-name>")
307	}
308
309	const whSite = "https://webhook.site"
310	req, err := http.NewRequest(http.MethodPost, whSite+"/token", nil)
311	check(ts, err, neg)
312
313	resp, err := http.DefaultClient.Do(req)
314	check(ts, err, neg)
315
316	defer resp.Body.Close()
317	var site webhookSite
318	check(ts, json.NewDecoder(resp.Body).Decode(&site), neg)
319
320	ts.Setenv(args[0], whSite+"/"+site.UUID)
321}
322
323func cmdCurl(ts *testscript.TestScript, neg bool, args []string) {
324	var verbose bool
325	var headers []string
326	var data string
327	method := http.MethodGet
328
329	cmd := &cobra.Command{
330		Use:  "curl",
331		Args: cobra.MinimumNArgs(1),
332		RunE: func(cmd *cobra.Command, args []string) error {
333			url, err := url.Parse(args[0])
334			if err != nil {
335				return err
336			}
337
338			req, err := http.NewRequest(method, url.String(), nil)
339			if err != nil {
340				return err
341			}
342
343			if data != "" {
344				req.Body = io.NopCloser(strings.NewReader(data))
345			}
346
347			if verbose {
348				fmt.Fprintf(cmd.ErrOrStderr(), "< %s %s\n", req.Method, url.String())
349			}
350
351			for _, header := range headers {
352				parts := strings.SplitN(header, ":", 2)
353				if len(parts) != 2 {
354					return fmt.Errorf("invalid header: %s", header)
355				}
356				req.Header.Add(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
357			}
358
359			if userInfo := url.User; userInfo != nil {
360				password, _ := userInfo.Password()
361				req.SetBasicAuth(userInfo.Username(), password)
362			}
363
364			if verbose {
365				for key, values := range req.Header {
366					for _, value := range values {
367						fmt.Fprintf(cmd.ErrOrStderr(), "< %s: %s\n", key, value)
368					}
369				}
370			}
371
372			resp, err := http.DefaultClient.Do(req)
373			if err != nil {
374				return err
375			}
376
377			if verbose {
378				fmt.Fprintf(ts.Stderr(), "> %s\n", resp.Status)
379				for key, values := range resp.Header {
380					for _, value := range values {
381						fmt.Fprintf(cmd.ErrOrStderr(), "> %s: %s\n", key, value)
382					}
383				}
384			}
385
386			defer resp.Body.Close()
387			buf, err := io.ReadAll(resp.Body)
388			if err != nil {
389				return err
390			}
391
392			cmd.Print(string(buf))
393
394			return nil
395		},
396	}
397
398	cmd.SetArgs(args)
399	cmd.SetOut(ts.Stdout())
400	cmd.SetErr(ts.Stderr())
401
402	cmd.Flags().BoolVarP(&verbose, "verbose", "v", verbose, "verbose")
403	cmd.Flags().StringArrayVarP(&headers, "header", "H", nil, "HTTP header")
404	cmd.Flags().StringVarP(&method, "request", "X", method, "HTTP method")
405	cmd.Flags().StringVarP(&data, "data", "d", data, "HTTP data")
406
407	check(ts, cmd.Execute(), neg)
408}
409
410func cmdWaitforserver(ts *testscript.TestScript, neg bool, args []string) {
411	// wait until the server is up
412	for {
413		conn, _ := net.DialTimeout(
414			"tcp",
415			net.JoinHostPort("localhost", fmt.Sprintf("%s", ts.Getenv("SSH_PORT"))),
416			time.Second,
417		)
418		if conn != nil {
419			conn.Close()
420			break
421		}
422	}
423}
424
425func cmdStopserver(ts *testscript.TestScript, neg bool, args []string) {
426	// stop the server
427	resp, err := http.DefaultClient.Head(fmt.Sprintf("%s/__stop", ts.Getenv("SOFT_SERVE_HTTP_PUBLIC_URL")))
428	check(ts, err, neg)
429	defer resp.Body.Close()
430	time.Sleep(time.Second * 2) // Allow some time for the server to stop
431}
432
433func setupPostgres(t testscript.T, cfg *config.Config) (error, func()) {
434	// Indicates postgres
435	// Create a disposable database
436	rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
437	dbName := fmt.Sprintf("softserve_test_%d", rnd.Int63())
438	dbDsn := cfg.DB.DataSource
439	if dbDsn == "" {
440		cfg.DB.DataSource = "postgres://postgres@localhost:5432/postgres?sslmode=disable"
441	}
442
443	dbUrl, err := url.Parse(cfg.DB.DataSource)
444	if err != nil {
445		return err, nil
446	}
447
448	scheme := dbUrl.Scheme
449	if scheme == "" {
450		scheme = "postgres"
451	}
452
453	host := dbUrl.Hostname()
454	if host == "" {
455		host = "localhost"
456	}
457
458	connInfo := fmt.Sprintf("host=%s sslmode=disable", host)
459	username := dbUrl.User.Username()
460	if username != "" {
461		connInfo += fmt.Sprintf(" user=%s", username)
462		password, ok := dbUrl.User.Password()
463		if ok {
464			username = fmt.Sprintf("%s:%s", username, password)
465			connInfo += fmt.Sprintf(" password=%s", password)
466		}
467		username = fmt.Sprintf("%s@", username)
468	} else {
469		connInfo += " user=postgres"
470		username = "postgres@"
471	}
472
473	port := dbUrl.Port()
474	if port != "" {
475		connInfo += fmt.Sprintf(" port=%s", port)
476		port = fmt.Sprintf(":%s", port)
477	}
478
479	cfg.DB.DataSource = fmt.Sprintf("%s://%s%s%s/%s?sslmode=disable",
480		scheme,
481		username,
482		host,
483		port,
484		dbName,
485	)
486
487	// Create the database
488	dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
489	if err != nil {
490		return err, nil
491	}
492
493	if _, err := dbx.Exec("CREATE DATABASE " + dbName); err != nil {
494		return err, nil
495	}
496
497	return nil, func() {
498		dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
499		if err != nil {
500			t.Fatal("failed to open database", dbName, err)
501		}
502
503		if _, err := dbx.Exec("DROP DATABASE " + dbName); err != nil {
504			t.Fatal("failed to drop database", dbName, err)
505		}
506	}
507}