1package testscript
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "flag"
8 "fmt"
9 "io"
10 "math/rand"
11 "net"
12 "net/http"
13 "net/url"
14 "os"
15 "os/exec"
16 "path/filepath"
17 "runtime"
18 "strconv"
19 "strings"
20 "testing"
21 "time"
22
23 "github.com/charmbracelet/keygen"
24 "github.com/charmbracelet/soft-serve/pkg/config"
25 "github.com/charmbracelet/soft-serve/pkg/db"
26 "github.com/charmbracelet/soft-serve/pkg/test"
27 "github.com/rogpeppe/go-internal/testscript"
28 "github.com/spf13/cobra"
29 "golang.org/x/crypto/ssh"
30)
31
// update is the -update test flag. Its value is passed to testscript as
// Params.UpdateScripts so script files can be rewritten with actual output.
var update = flag.Bool("update", false, "update script files")

// binPath is where TestMain places the compiled soft binary; TestScript puts
// its directory at the front of every script's PATH.
var binPath string
36
// PrepareBuildCommand returns a `go build -cover` command that compiles
// ../cmd/soft into binPath. The race detector (-race) is added unless the
// SOFT_SERVE_DISABLE_RACE_CHECKS environment variable is set to any value
// (including the empty string). The command is returned unstarted.
func PrepareBuildCommand(binPath string) *exec.Cmd {
	buildArgs := []string{"build"}
	if _, skipRace := os.LookupEnv("SOFT_SERVE_DISABLE_RACE_CHECKS"); !skipRace {
		buildArgs = append(buildArgs, "-race")
	}
	buildArgs = append(buildArgs, "-cover", "-o", binPath, filepath.Join("..", "cmd", "soft"))
	return exec.Command("go", buildArgs...)
}
45
46func TestMain(m *testing.M) {
47 tmp, err := os.MkdirTemp("", "soft-serve*")
48 if err != nil {
49 fmt.Fprintf(os.Stderr, "failed to create temporary directory: %s", err)
50 os.Exit(1)
51 }
52 defer os.RemoveAll(tmp)
53
54 binPath = filepath.Join(tmp, "soft")
55 if runtime.GOOS == "windows" {
56 binPath += ".exe"
57 }
58
59 // Build the soft binary with -cover flag.
60 cmd := PrepareBuildCommand(binPath)
61 if err := cmd.Run(); err != nil {
62 fmt.Fprintf(os.Stderr, "failed to build soft-serve binary: %s", err)
63 os.Exit(1)
64 }
65
66 // Run tests
67 os.Exit(m.Run())
68}
69
70func TestScript(t *testing.T) {
71 flag.Parse()
72
73 mkkey := func(name string) (string, *keygen.SSHKeyPair) {
74 path := filepath.Join(t.TempDir(), name)
75 pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
76 if err != nil {
77 t.Fatal(err)
78 }
79 return path, pair
80 }
81
82 admin1Key, admin1 := mkkey("admin1")
83 _, admin2 := mkkey("admin2")
84 user1Key, user1 := mkkey("user1")
85
86 testscript.Run(t, testscript.Params{
87 Dir: "./testdata/",
88 UpdateScripts: *update,
89 RequireExplicitExec: true,
90 Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
91 "soft": cmdSoft("admin", admin1.Signer()),
92 "usoft": cmdSoft("user1", user1.Signer()),
93 "git": cmdGit(admin1Key),
94 "ugit": cmdGit(user1Key),
95 "curl": cmdCurl,
96 "mkfile": cmdMkfile,
97 "envfile": cmdEnvfile,
98 "readfile": cmdReadfile,
99 "dos2unix": cmdDos2Unix,
100 "new-webhook": cmdNewWebhook,
101 "ensureserverrunning": cmdEnsureServerRunning,
102 "ensureservernotrunning": cmdEnsureServerNotRunning,
103 "stopserver": cmdStopserver,
104 "ui": cmdUI(admin1.Signer()),
105 "uui": cmdUI(user1.Signer()),
106 },
107 Setup: func(e *testscript.Env) error {
108 // Add binPath to PATH
109 e.Setenv("PATH", fmt.Sprintf("%s%c%s", filepath.Dir(binPath), os.PathListSeparator, e.Getenv("PATH")))
110
111 data := t.TempDir()
112 sshPort := test.RandomPort()
113 sshListen := fmt.Sprintf("localhost:%d", sshPort)
114 gitPort := test.RandomPort()
115 gitListen := fmt.Sprintf("localhost:%d", gitPort)
116 httpPort := test.RandomPort()
117 httpListen := fmt.Sprintf("localhost:%d", httpPort)
118 statsPort := test.RandomPort()
119 statsListen := fmt.Sprintf("localhost:%d", statsPort)
120 serverName := "Test Soft Serve"
121
122 e.Setenv("DATA_PATH", data)
123 e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
124 e.Setenv("HTTP_PORT", fmt.Sprintf("%d", httpPort))
125 e.Setenv("STATS_PORT", fmt.Sprintf("%d", statsPort))
126 e.Setenv("GIT_PORT", fmt.Sprintf("%d", gitPort))
127 e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
128 e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
129 e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
130 e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
131 e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
132
133 // This is used to set up test specific configuration and http endpoints
134 e.Setenv("SOFT_SERVE_TESTRUN", "1")
135
136 // This will disable the default lipgloss renderer colors
137 e.Setenv("SOFT_SERVE_NO_COLOR", "1")
138
139 // Soft Serve debug environment variables
140 for _, env := range []string{
141 "SOFT_SERVE_DEBUG",
142 "SOFT_SERVE_VERBOSE",
143 } {
144 if v, ok := os.LookupEnv(env); ok {
145 e.Setenv(env, v)
146 }
147 }
148
149 // TODO: test different configs
150 cfg := config.DefaultConfig()
151 cfg.DataPath = data
152 cfg.Name = serverName
153 cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
154 cfg.SSH.ListenAddr = sshListen
155 cfg.SSH.PublicURL = "ssh://" + sshListen
156 cfg.Git.ListenAddr = gitListen
157 cfg.HTTP.ListenAddr = httpListen
158 cfg.HTTP.PublicURL = "http://" + httpListen
159 cfg.Stats.ListenAddr = statsListen
160 cfg.LFS.Enabled = true
161
162 // Parse os SOFT_SERVE environment variables
163 if err := cfg.ParseEnv(); err != nil {
164 return err
165 }
166
167 // Override the database data source if we're using postgres
168 // so we can create a temporary database for the tests.
169 if cfg.DB.Driver == "postgres" {
170 err, cleanup := setupPostgres(e.T(), cfg)
171 if err != nil {
172 return err
173 }
174 if cleanup != nil {
175 e.Defer(cleanup)
176 }
177 }
178
179 for _, env := range cfg.Environ() {
180 parts := strings.SplitN(env, "=", 2)
181 if len(parts) != 2 {
182 e.T().Fatal("invalid environment variable", env)
183 }
184 e.Setenv(parts[0], parts[1])
185 }
186
187 return nil
188 },
189 })
190}
191
192func cmdSoft(user string, key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
193 return func(ts *testscript.TestScript, neg bool, args []string) {
194 cli, err := ssh.Dial(
195 "tcp",
196 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
197 &ssh.ClientConfig{
198 User: user,
199 Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
200 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
201 },
202 )
203 ts.Check(err)
204 defer cli.Close()
205
206 sess, err := cli.NewSession()
207 ts.Check(err)
208 defer sess.Close()
209
210 sess.Stdout = ts.Stdout()
211 sess.Stderr = ts.Stderr()
212
213 check(ts, sess.Run(strings.Join(args, " ")), neg)
214 }
215}
216
217func cmdUI(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
218 return func(ts *testscript.TestScript, neg bool, args []string) {
219 if len(args) < 1 {
220 ts.Fatalf("usage: ui <quoted string input>")
221 return
222 }
223
224 cli, err := ssh.Dial(
225 "tcp",
226 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
227 &ssh.ClientConfig{
228 User: "git",
229 Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
230 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
231 },
232 )
233 check(ts, err, neg)
234 defer cli.Close()
235
236 sess, err := cli.NewSession()
237 check(ts, err, neg)
238 defer sess.Close()
239
240 // XXX: this is a hack to make the UI tests work
241 // cmp command always complains about an extra newline
242 // in the output
243 defer ts.Stdout().Write([]byte("\n"))
244
245 sess.Stdout = ts.Stdout()
246 sess.Stderr = ts.Stderr()
247
248 stdin, err := sess.StdinPipe()
249 check(ts, err, neg)
250
251 err = sess.RequestPty("dumb", 40, 80, ssh.TerminalModes{})
252 check(ts, err, neg)
253 check(ts, sess.Start(""), neg)
254
255 in, err := strconv.Unquote(args[0])
256 check(ts, err, neg)
257 reader := strings.NewReader(in)
258 go func() {
259 defer stdin.Close()
260 for {
261 r, _, err := reader.ReadRune()
262 if err == io.EOF {
263 break
264 }
265 check(ts, err, neg)
266 stdin.Write([]byte(string(r))) //nolint: errcheck
267
268 // Wait for the UI to process the input
269 time.Sleep(100 * time.Millisecond)
270 }
271 }()
272
273 check(ts, sess.Wait(), neg)
274 }
275}
276
277func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
278 if neg {
279 ts.Fatalf("unsupported: ! dos2unix")
280 }
281 if len(args) < 1 {
282 ts.Fatalf("usage: dos2unix paths...")
283 }
284 for _, arg := range args {
285 filename := ts.MkAbs(arg)
286 data, err := os.ReadFile(filename)
287 if err != nil {
288 ts.Fatalf("%s: %v", filename, err)
289 }
290
291 // Replace all '\r\n' with '\n'.
292 data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
293
294 if err := os.WriteFile(filename, data, 0o644); err != nil {
295 ts.Fatalf("%s: %v", filename, err)
296 }
297 }
298}
299
// sshConfig is a fmt format template for an OpenSSH client config. The single
// %q verb is filled with the known_hosts file path (quoted) before the result
// is written to disk for use with `ssh -F`. Per ssh_config(5), the options
// disable strict host-key checking, keep ssh-agent out of authentication, and
// restrict authentication to explicitly configured identities (e.g. -i).
var sshConfig = `
Host *
  UserKnownHostsFile %q
  StrictHostKeyChecking no
  IdentityAgent none
  IdentitiesOnly yes
  ServerAliveInterval 60
`
308
309func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
310 return func(ts *testscript.TestScript, neg bool, args []string) {
311 ts.Check(os.WriteFile(
312 ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
313 []byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
314 0o600,
315 ))
316 sshArgs := []string{
317 "-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
318 "-i", filepath.ToSlash(key),
319 }
320 ts.Setenv(
321 "GIT_SSH_COMMAND",
322 strings.Join(append([]string{"ssh"}, sshArgs...), " "),
323 )
324 // Disable git prompting for credentials.
325 ts.Setenv("GIT_TERMINAL_PROMPT", "0")
326 args = append([]string{
327 "-c", "user.email=john@example.com",
328 "-c", "user.name=John Doe",
329 }, args...)
330 check(ts, ts.Exec("git", args...), neg)
331 }
332}
333
334func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
335 if len(args) < 2 {
336 ts.Fatalf("usage: mkfile path content")
337 }
338 check(ts, os.WriteFile(
339 ts.MkAbs(args[0]),
340 []byte(strings.Join(args[1:], " ")),
341 0o644,
342 ), neg)
343}
344
345func check(ts *testscript.TestScript, err error, neg bool) {
346 if neg && err == nil {
347 ts.Fatalf("expected error, got nil")
348 }
349 if !neg {
350 ts.Check(err)
351 }
352}
353
354func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) {
355 ts.Stdout().Write([]byte(ts.ReadFile(args[0])))
356}
357
358func cmdEnvfile(ts *testscript.TestScript, neg bool, args []string) {
359 if len(args) < 1 {
360 ts.Fatalf("usage: envfile key=file...")
361 }
362
363 for _, arg := range args {
364 parts := strings.SplitN(arg, "=", 2)
365 if len(parts) != 2 {
366 ts.Fatalf("usage: envfile key=file...")
367 }
368 key := parts[0]
369 file := parts[1]
370 ts.Setenv(key, strings.TrimSpace(ts.ReadFile(file)))
371 }
372}
373
374func cmdNewWebhook(ts *testscript.TestScript, neg bool, args []string) {
375 type webhookSite struct {
376 UUID string `json:"uuid"`
377 }
378
379 if len(args) != 1 {
380 ts.Fatalf("usage: new-webhook <env-name>")
381 }
382
383 const whSite = "https://webhook.site"
384 req, err := http.NewRequest(http.MethodPost, whSite+"/token", nil)
385 check(ts, err, neg)
386
387 resp, err := http.DefaultClient.Do(req)
388 check(ts, err, neg)
389
390 defer resp.Body.Close()
391 var site webhookSite
392 check(ts, json.NewDecoder(resp.Body).Decode(&site), neg)
393
394 ts.Setenv(args[0], whSite+"/"+site.UUID)
395}
396
397func cmdCurl(ts *testscript.TestScript, neg bool, args []string) {
398 var verbose bool
399 var headers []string
400 var data string
401 method := http.MethodGet
402
403 cmd := &cobra.Command{
404 Use: "curl",
405 Args: cobra.MinimumNArgs(1),
406 RunE: func(cmd *cobra.Command, args []string) error {
407 url, err := url.Parse(args[0])
408 if err != nil {
409 return err
410 }
411
412 req, err := http.NewRequest(method, url.String(), nil)
413 if err != nil {
414 return err
415 }
416
417 if data != "" {
418 req.Body = io.NopCloser(strings.NewReader(data))
419 }
420
421 if verbose {
422 fmt.Fprintf(cmd.ErrOrStderr(), "< %s %s\n", req.Method, url.String())
423 }
424
425 for _, header := range headers {
426 parts := strings.SplitN(header, ":", 2)
427 if len(parts) != 2 {
428 return fmt.Errorf("invalid header: %s", header)
429 }
430 req.Header.Add(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
431 }
432
433 if userInfo := url.User; userInfo != nil {
434 password, _ := userInfo.Password()
435 req.SetBasicAuth(userInfo.Username(), password)
436 }
437
438 if verbose {
439 for key, values := range req.Header {
440 for _, value := range values {
441 fmt.Fprintf(cmd.ErrOrStderr(), "< %s: %s\n", key, value)
442 }
443 }
444 }
445
446 resp, err := http.DefaultClient.Do(req)
447 if err != nil {
448 return err
449 }
450
451 if verbose {
452 fmt.Fprintf(ts.Stderr(), "> %s\n", resp.Status)
453 for key, values := range resp.Header {
454 for _, value := range values {
455 fmt.Fprintf(cmd.ErrOrStderr(), "> %s: %s\n", key, value)
456 }
457 }
458 }
459
460 defer resp.Body.Close()
461 buf, err := io.ReadAll(resp.Body)
462 if err != nil {
463 return err
464 }
465
466 cmd.Print(string(buf))
467
468 return nil
469 },
470 }
471
472 cmd.SetArgs(args)
473 cmd.SetOut(ts.Stdout())
474 cmd.SetErr(ts.Stderr())
475
476 cmd.Flags().BoolVarP(&verbose, "verbose", "v", verbose, "verbose")
477 cmd.Flags().StringArrayVarP(&headers, "header", "H", nil, "HTTP header")
478 cmd.Flags().StringVarP(&method, "request", "X", method, "HTTP method")
479 cmd.Flags().StringVarP(&data, "data", "d", data, "HTTP data")
480
481 check(ts, cmd.Execute(), neg)
482}
483
484func cmdEnsureServerRunning(ts *testscript.TestScript, neg bool, args []string) {
485 if len(args) < 1 {
486 ts.Fatalf("Must supply a TCP port of one of the services to connect to. " +
487 "These are set as env vars as they are randomized. " +
488 "Example usage: \"cmdensureserverrunning SSH_PORT\"\n" +
489 "Valid values for the env var: SSH_PORT|HTTP_PORT|GIT_PORT|STATS_PORT")
490 }
491
492 port := ts.Getenv(args[0])
493
494 // verify that the server is up
495 addr := net.JoinHostPort("localhost", port)
496 for {
497 conn, _ := net.DialTimeout(
498 "tcp",
499 addr,
500 time.Second,
501 )
502 if conn != nil {
503 ts.Logf("Server is running on port: %s", port)
504 conn.Close()
505 break
506 }
507 }
508}
509
510func cmdEnsureServerNotRunning(ts *testscript.TestScript, neg bool, args []string) {
511 if len(args) < 1 {
512 ts.Fatalf("Must supply a TCP port of one of the services to connect to. " +
513 "These are set as env vars as they are randomized. " +
514 "Example usage: \"cmdensureservernotrunning SSH_PORT\"\n" +
515 "Valid values for the env var: SSH_PORT|HTTP_PORT|GIT_PORT|STATS_PORT")
516 }
517
518 port := ts.Getenv(args[0])
519
520 // verify that the server is not up
521 addr := net.JoinHostPort("localhost", port)
522 for {
523 conn, _ := net.DialTimeout(
524 "tcp",
525 addr,
526 time.Second,
527 )
528 if conn != nil {
529 ts.Fatalf("server is running on port %s while it should not be running", port)
530 conn.Close()
531 }
532 break
533 }
534}
535
536func cmdStopserver(ts *testscript.TestScript, neg bool, args []string) {
537 // stop the server
538 resp, err := http.DefaultClient.Head(fmt.Sprintf("%s/__stop", ts.Getenv("SOFT_SERVE_HTTP_PUBLIC_URL")))
539 check(ts, err, neg)
540 resp.Body.Close()
541 time.Sleep(time.Second * 2) // Allow some time for the server to stop
542}
543
544func setupPostgres(t testscript.T, cfg *config.Config) (error, func()) {
545 // Indicates postgres
546 // Create a disposable database
547 rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
548 dbName := fmt.Sprintf("softserve_test_%d", rnd.Int63())
549 dbDsn := cfg.DB.DataSource
550 if dbDsn == "" {
551 cfg.DB.DataSource = "postgres://postgres@localhost:5432/postgres?sslmode=disable"
552 }
553
554 dbUrl, err := url.Parse(cfg.DB.DataSource)
555 if err != nil {
556 return err, nil
557 }
558
559 scheme := dbUrl.Scheme
560 if scheme == "" {
561 scheme = "postgres"
562 }
563
564 host := dbUrl.Hostname()
565 if host == "" {
566 host = "localhost"
567 }
568
569 connInfo := fmt.Sprintf("host=%s sslmode=disable", host)
570 username := dbUrl.User.Username()
571 if username != "" {
572 connInfo += fmt.Sprintf(" user=%s", username)
573 password, ok := dbUrl.User.Password()
574 if ok {
575 username = fmt.Sprintf("%s:%s", username, password)
576 connInfo += fmt.Sprintf(" password=%s", password)
577 }
578 username = fmt.Sprintf("%s@", username)
579 } else {
580 connInfo += " user=postgres"
581 username = "postgres@"
582 }
583
584 port := dbUrl.Port()
585 if port != "" {
586 connInfo += fmt.Sprintf(" port=%s", port)
587 port = fmt.Sprintf(":%s", port)
588 }
589
590 cfg.DB.DataSource = fmt.Sprintf("%s://%s%s%s/%s?sslmode=disable",
591 scheme,
592 username,
593 host,
594 port,
595 dbName,
596 )
597
598 // Create the database
599 dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
600 if err != nil {
601 return err, nil
602 }
603
604 if _, err := dbx.Exec("CREATE DATABASE " + dbName); err != nil {
605 return err, nil
606 }
607
608 return nil, func() {
609 dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
610 if err != nil {
611 t.Fatal("failed to open database", dbName, err)
612 }
613
614 if _, err := dbx.Exec("DROP DATABASE " + dbName); err != nil {
615 t.Fatal("failed to drop database", dbName, err)
616 }
617 }
618}