1package testscript
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "flag"
8 "fmt"
9 "io"
10 "math/rand"
11 "net"
12 "net/http"
13 "net/url"
14 "os"
15 "os/exec"
16 "path/filepath"
17 "runtime"
18 "strconv"
19 "strings"
20 "testing"
21 "time"
22
23 "github.com/charmbracelet/keygen"
24 "github.com/charmbracelet/soft-serve/pkg/config"
25 "github.com/charmbracelet/soft-serve/pkg/db"
26 "github.com/charmbracelet/soft-serve/pkg/test"
27 "github.com/rogpeppe/go-internal/testscript"
28 "github.com/spf13/cobra"
29 "golang.org/x/crypto/ssh"
30)
31
var (
	// update is the -update test flag. It is passed through to
	// testscript.Params.UpdateScripts (flag help: "update script files").
	update = flag.Bool("update", false, "update script files")
	// binPath is where TestMain builds the soft binary; Setup prepends its
	// directory to PATH for every script.
	binPath string
)
36
// PrepareBuildCommand returns an unstarted "go build" command that compiles
// ../cmd/soft with coverage instrumentation (-cover) into binPath.
//
// The race detector (-race) is enabled unless SOFT_SERVE_DISABLE_RACE_CHECKS
// is present in the environment; only its presence matters, not its value.
func PrepareBuildCommand(binPath string) *exec.Cmd {
	buildArgs := []string{"build"}
	if _, skipRace := os.LookupEnv("SOFT_SERVE_DISABLE_RACE_CHECKS"); !skipRace {
		buildArgs = append(buildArgs, "-race")
	}
	buildArgs = append(buildArgs, "-cover", "-o", binPath, filepath.Join("..", "cmd", "soft"))
	return exec.Command("go", buildArgs...)
}
45
46func TestMain(m *testing.M) {
47 tmp, err := os.MkdirTemp("", "soft-serve*")
48 if err != nil {
49 fmt.Fprintf(os.Stderr, "failed to create temporary directory: %s", err)
50 os.Exit(1)
51 }
52 defer os.RemoveAll(tmp)
53
54 binPath = filepath.Join(tmp, "soft")
55 if runtime.GOOS == "windows" {
56 binPath += ".exe"
57 }
58
59 // Build the soft binary with -cover flag.
60 cmd := PrepareBuildCommand(binPath)
61 if err := cmd.Run(); err != nil {
62 fmt.Fprintf(os.Stderr, "failed to build soft-serve binary: %s", err)
63 os.Exit(1)
64 }
65
66 // Run tests
67 os.Exit(m.Run())
68}
69
// TestScript runs every script under ./testdata with testscript.
//
// Three ed25519 key pairs (admin1, admin2, user1) are generated once and
// shared by all scripts; admin1 is configured as the initial admin. Before
// each script, Setup prepares a fresh data directory, listen ports obtained
// from test.RandomPort, and a server configuration exported into the script's
// environment. With RequireExplicitExec, external programs must be run via
// "exec"; the custom commands available to scripts are registered in Cmds.
func TestScript(t *testing.T) {
	// NOTE(review): testing.M.Run (called from TestMain) already parses
	// flags, so this call looks redundant; kept as-is.
	flag.Parse()

	// mkkey generates an ed25519 key pair and writes it to a file called name
	// in a fresh temporary directory. It returns that file's path (later
	// passed to ssh via -i by cmdGit) together with the pair. Any failure
	// aborts the whole test.
	mkkey := func(name string) (string, *keygen.SSHKeyPair) {
		path := filepath.Join(t.TempDir(), name)
		pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
		if err != nil {
			t.Fatal(err)
		}
		return path, pair
	}

	admin1Key, admin1 := mkkey("admin1")
	// Only admin2's authorized key is used (ADMIN2_AUTHORIZED_KEY); its key file path is not needed.
	_, admin2 := mkkey("admin2")
	user1Key, user1 := mkkey("user1")

	testscript.Run(t, testscript.Params{
		Dir:                 "./testdata/",
		UpdateScripts:       *update,
		RequireExplicitExec: true,
		// Custom script commands. The "u"-prefixed variants authenticate with
		// user1's key; their counterparts use admin1's key.
		Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
			"soft":                   cmdSoft("admin", admin1.Signer()),
			"usoft":                  cmdSoft("user1", user1.Signer()),
			"git":                    cmdGit(admin1Key),
			"ugit":                   cmdGit(user1Key),
			"curl":                   cmdCurl,
			"mkfile":                 cmdMkfile,
			"envfile":                cmdEnvfile,
			"readfile":               cmdReadfile,
			"dos2unix":               cmdDos2Unix,
			"new-webhook":            cmdNewWebhook,
			"ensureserverrunning":    cmdEnsureServerRunning,
			"ensureservernotrunning": cmdEnsureServerNotRunning,
			"stopserver":             cmdStopserver,
			"ui":                     cmdUI(admin1.Signer()),
			"uui":                    cmdUI(user1.Signer()),
		},
		// Setup runs before each script and gives it an isolated server
		// environment: its own data directory, ports and configuration.
		Setup: func(e *testscript.Env) error {
			// Add binPath to PATH
			e.Setenv("PATH", fmt.Sprintf("%s%c%s", filepath.Dir(binPath), os.PathListSeparator, e.Getenv("PATH")))

			// One port per listener (SSH, git daemon, HTTP, stats), all on localhost.
			data := t.TempDir()
			sshPort := test.RandomPort()
			sshListen := fmt.Sprintf("localhost:%d", sshPort)
			gitPort := test.RandomPort()
			gitListen := fmt.Sprintf("localhost:%d", gitPort)
			httpPort := test.RandomPort()
			httpListen := fmt.Sprintf("localhost:%d", httpPort)
			statsPort := test.RandomPort()
			statsListen := fmt.Sprintf("localhost:%d", statsPort)
			serverName := "Test Soft Serve"

			// Values the script and the custom commands read back via Getenv
			// (e.g. SSH_PORT in cmdSoft, SSH_KNOWN_* in cmdGit).
			e.Setenv("DATA_PATH", data)
			e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
			e.Setenv("HTTP_PORT", fmt.Sprintf("%d", httpPort))
			e.Setenv("STATS_PORT", fmt.Sprintf("%d", statsPort))
			e.Setenv("GIT_PORT", fmt.Sprintf("%d", gitPort))
			e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
			e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
			e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
			e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
			e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))

			// This is used to set up test specific configuration and http endpoints
			e.Setenv("SOFT_SERVE_TESTRUN", "1")

			// This will disable the default lipgloss renderer colors
			e.Setenv("SOFT_SERVE_NO_COLOR", "1")

			// Soft Serve debug environment variables
			for _, env := range []string{
				"SOFT_SERVE_DEBUG",
				"SOFT_SERVE_VERBOSE",
			} {
				if v, ok := os.LookupEnv(env); ok {
					e.Setenv(env, v)
				}
			}

			// TODO: test different configs
			cfg := config.DefaultConfig()
			cfg.DataPath = data
			cfg.Name = serverName
			cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
			cfg.SSH.ListenAddr = sshListen
			cfg.SSH.PublicURL = "ssh://" + sshListen
			cfg.Git.ListenAddr = gitListen
			cfg.HTTP.ListenAddr = httpListen
			cfg.HTTP.PublicURL = "http://" + httpListen
			cfg.Stats.ListenAddr = statsListen
			cfg.LFS.Enabled = true

			// Parse os SOFT_SERVE environment variables
			// (presumably letting the host environment override the values
			// above, e.g. to select the postgres driver checked below).
			if err := cfg.ParseEnv(); err != nil {
				return err
			}

			// Override the database data source if we're using postgres
			// so we can create a temporary database for the tests.
			if cfg.DB.Driver == "postgres" {
				err, cleanup := setupPostgres(e.T(), cfg)
				if err != nil {
					return err
				}
				if cleanup != nil {
					e.Defer(cleanup)
				}
			}

			// Export the final configuration (including any DataSource
			// rewritten by setupPostgres) into the script environment as
			// KEY=VALUE pairs produced by cfg.Environ.
			for _, env := range cfg.Environ() {
				parts := strings.SplitN(env, "=", 2)
				if len(parts) != 2 {
					e.T().Fatal("invalid environment variable", env)
				}
				e.Setenv(parts[0], parts[1])
			}

			return nil
		},
	})
}
191
192func cmdSoft(user string, key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
193 return func(ts *testscript.TestScript, neg bool, args []string) {
194 cli, err := ssh.Dial(
195 "tcp",
196 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
197 &ssh.ClientConfig{
198 User: user,
199 Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
200 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
201 },
202 )
203 ts.Check(err)
204 defer cli.Close()
205
206 sess, err := cli.NewSession()
207 ts.Check(err)
208 defer sess.Close()
209
210 sess.Stdout = ts.Stdout()
211 sess.Stderr = ts.Stderr()
212
213 check(ts, sess.Run(strings.Join(args, " ")), neg)
214 }
215}
216
217func cmdUI(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
218 return func(ts *testscript.TestScript, neg bool, args []string) {
219 if len(args) < 1 {
220 ts.Fatalf("usage: ui <quoted string input>")
221 return
222 }
223
224 cli, err := ssh.Dial(
225 "tcp",
226 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
227 &ssh.ClientConfig{
228 User: "git",
229 Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
230 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
231 },
232 )
233 check(ts, err, neg)
234 defer cli.Close()
235
236 sess, err := cli.NewSession()
237 check(ts, err, neg)
238 defer sess.Close()
239
240 // XXX: this is a hack to make the UI tests work
241 // cmp command always complains about an extra newline
242 // in the output
243 defer ts.Stdout().Write([]byte("\n"))
244
245 sess.Stdout = ts.Stdout()
246 sess.Stderr = ts.Stderr()
247
248 stdin, err := sess.StdinPipe()
249 check(ts, err, neg)
250
251 err = sess.RequestPty("dumb", 40, 80, ssh.TerminalModes{})
252 check(ts, err, neg)
253 check(ts, sess.Start(""), neg)
254
255 in, err := strconv.Unquote(args[0])
256 check(ts, err, neg)
257 reader := strings.NewReader(in)
258 go func() {
259 defer stdin.Close()
260 for {
261 r, _, err := reader.ReadRune()
262 if err == io.EOF {
263 break
264 }
265 check(ts, err, neg)
266 stdin.Write([]byte(string(r))) // nolint: errcheck
267
268 // Wait for the UI to process the input
269 time.Sleep(100 * time.Millisecond)
270 }
271 }()
272
273 check(ts, sess.Wait(), neg)
274 }
275}
276
277// P.S. Windows sucks!
278func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
279 if neg {
280 ts.Fatalf("unsupported: ! dos2unix")
281 }
282 if len(args) < 1 {
283 ts.Fatalf("usage: dos2unix paths...")
284 }
285 for _, arg := range args {
286 filename := ts.MkAbs(arg)
287 data, err := os.ReadFile(filename)
288 if err != nil {
289 ts.Fatalf("%s: %v", filename, err)
290 }
291
292 // Replace all '\r\n' with '\n'.
293 data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
294
295 if err := os.WriteFile(filename, data, 0o644); err != nil {
296 ts.Fatalf("%s: %v", filename, err)
297 }
298 }
299}
300
// sshConfig is an OpenSSH client configuration template that cmdGit renders
// and writes to $SSH_KNOWN_CONFIG_FILE. Its single %q verb receives the
// path of the known_hosts file ($SSH_KNOWN_HOSTS_FILE).
//
// Strict host key checking is turned off and the SSH agent is disabled
// (IdentityAgent none). IdentitiesOnly restricts ssh to explicitly configured
// identities such as the -i key passed by cmdGit, so git authenticates with
// exactly the key chosen by the caller.
var sshConfig = `
Host *
  UserKnownHostsFile %q
  StrictHostKeyChecking no
  IdentityAgent none
  IdentitiesOnly yes
  ServerAliveInterval 60
`
309
310func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
311 return func(ts *testscript.TestScript, neg bool, args []string) {
312 ts.Check(os.WriteFile(
313 ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
314 []byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
315 0o600,
316 ))
317 sshArgs := []string{
318 "-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
319 "-i", filepath.ToSlash(key),
320 }
321 ts.Setenv(
322 "GIT_SSH_COMMAND",
323 strings.Join(append([]string{"ssh"}, sshArgs...), " "),
324 )
325 // Disable git prompting for credentials.
326 ts.Setenv("GIT_TERMINAL_PROMPT", "0")
327 args = append([]string{
328 "-c", "user.email=john@example.com",
329 "-c", "user.name=John Doe",
330 }, args...)
331 check(ts, ts.Exec("git", args...), neg)
332 }
333}
334
335func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
336 if len(args) < 2 {
337 ts.Fatalf("usage: mkfile path content")
338 }
339 check(ts, os.WriteFile(
340 ts.MkAbs(args[0]),
341 []byte(strings.Join(args[1:], " ")),
342 0o644,
343 ), neg)
344}
345
346func check(ts *testscript.TestScript, err error, neg bool) {
347 if neg && err == nil {
348 ts.Fatalf("expected error, got nil")
349 }
350 if !neg {
351 ts.Check(err)
352 }
353}
354
355func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) {
356 ts.Stdout().Write([]byte(ts.ReadFile(args[0])))
357}
358
359func cmdEnvfile(ts *testscript.TestScript, neg bool, args []string) {
360 if len(args) < 1 {
361 ts.Fatalf("usage: envfile key=file...")
362 }
363
364 for _, arg := range args {
365 parts := strings.SplitN(arg, "=", 2)
366 if len(parts) != 2 {
367 ts.Fatalf("usage: envfile key=file...")
368 }
369 key := parts[0]
370 file := parts[1]
371 ts.Setenv(key, strings.TrimSpace(ts.ReadFile(file)))
372 }
373}
374
375func cmdNewWebhook(ts *testscript.TestScript, neg bool, args []string) {
376 type webhookSite struct {
377 UUID string `json:"uuid"`
378 }
379
380 if len(args) != 1 {
381 ts.Fatalf("usage: new-webhook <env-name>")
382 }
383
384 const whSite = "https://webhook.site"
385 req, err := http.NewRequest(http.MethodPost, whSite+"/token", nil)
386 check(ts, err, neg)
387
388 resp, err := http.DefaultClient.Do(req)
389 check(ts, err, neg)
390
391 defer resp.Body.Close()
392 var site webhookSite
393 check(ts, json.NewDecoder(resp.Body).Decode(&site), neg)
394
395 ts.Setenv(args[0], whSite+"/"+site.UUID)
396}
397
398func cmdCurl(ts *testscript.TestScript, neg bool, args []string) {
399 var verbose bool
400 var headers []string
401 var data string
402 method := http.MethodGet
403
404 cmd := &cobra.Command{
405 Use: "curl",
406 Args: cobra.MinimumNArgs(1),
407 RunE: func(cmd *cobra.Command, args []string) error {
408 url, err := url.Parse(args[0])
409 if err != nil {
410 return err
411 }
412
413 req, err := http.NewRequest(method, url.String(), nil)
414 if err != nil {
415 return err
416 }
417
418 if data != "" {
419 req.Body = io.NopCloser(strings.NewReader(data))
420 }
421
422 if verbose {
423 fmt.Fprintf(cmd.ErrOrStderr(), "< %s %s\n", req.Method, url.String())
424 }
425
426 for _, header := range headers {
427 parts := strings.SplitN(header, ":", 2)
428 if len(parts) != 2 {
429 return fmt.Errorf("invalid header: %s", header)
430 }
431 req.Header.Add(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
432 }
433
434 if userInfo := url.User; userInfo != nil {
435 password, _ := userInfo.Password()
436 req.SetBasicAuth(userInfo.Username(), password)
437 }
438
439 if verbose {
440 for key, values := range req.Header {
441 for _, value := range values {
442 fmt.Fprintf(cmd.ErrOrStderr(), "< %s: %s\n", key, value)
443 }
444 }
445 }
446
447 resp, err := http.DefaultClient.Do(req)
448 if err != nil {
449 return err
450 }
451
452 if verbose {
453 fmt.Fprintf(ts.Stderr(), "> %s\n", resp.Status)
454 for key, values := range resp.Header {
455 for _, value := range values {
456 fmt.Fprintf(cmd.ErrOrStderr(), "> %s: %s\n", key, value)
457 }
458 }
459 }
460
461 defer resp.Body.Close()
462 buf, err := io.ReadAll(resp.Body)
463 if err != nil {
464 return err
465 }
466
467 cmd.Print(string(buf))
468
469 return nil
470 },
471 }
472
473 cmd.SetArgs(args)
474 cmd.SetOut(ts.Stdout())
475 cmd.SetErr(ts.Stderr())
476
477 cmd.Flags().BoolVarP(&verbose, "verbose", "v", verbose, "verbose")
478 cmd.Flags().StringArrayVarP(&headers, "header", "H", nil, "HTTP header")
479 cmd.Flags().StringVarP(&method, "request", "X", method, "HTTP method")
480 cmd.Flags().StringVarP(&data, "data", "d", data, "HTTP data")
481
482 check(ts, cmd.Execute(), neg)
483}
484
485func cmdEnsureServerRunning(ts *testscript.TestScript, neg bool, args []string) {
486 if len(args) < 1 {
487 ts.Fatalf("Must supply a TCP port of one of the services to connect to. " +
488 "These are set as env vars as they are randomized. " +
489 "Example usage: \"cmdensureserverrunning SSH_PORT\"\n" +
490 "Valid values for the env var: SSH_PORT|HTTP_PORT|GIT_PORT|STATS_PORT")
491 }
492
493 port := ts.Getenv(args[0])
494
495 // verify that the server is up
496 addr := net.JoinHostPort("localhost", port)
497 for {
498 conn, _ := net.DialTimeout(
499 "tcp",
500 addr,
501 time.Second,
502 )
503 if conn != nil {
504 ts.Logf("Server is running on port: %s", port)
505 conn.Close()
506 break
507 }
508 }
509}
510
511func cmdEnsureServerNotRunning(ts *testscript.TestScript, neg bool, args []string) {
512 if len(args) < 1 {
513 ts.Fatalf("Must supply a TCP port of one of the services to connect to. " +
514 "These are set as env vars as they are randomized. " +
515 "Example usage: \"cmdensureservernotrunning SSH_PORT\"\n" +
516 "Valid values for the env var: SSH_PORT|HTTP_PORT|GIT_PORT|STATS_PORT")
517 }
518
519 port := ts.Getenv(args[0])
520
521 // verify that the server is not up
522 addr := net.JoinHostPort("localhost", port)
523 for {
524 conn, _ := net.DialTimeout(
525 "tcp",
526 addr,
527 time.Second,
528 )
529 if conn != nil {
530 ts.Fatalf("server is running on port %s while it should not be running", port)
531 conn.Close()
532 }
533 break
534 }
535}
536
537func cmdStopserver(ts *testscript.TestScript, neg bool, args []string) {
538 // stop the server
539 resp, err := http.DefaultClient.Head(fmt.Sprintf("%s/__stop", ts.Getenv("SOFT_SERVE_HTTP_PUBLIC_URL")))
540 check(ts, err, neg)
541 resp.Body.Close()
542 time.Sleep(time.Second * 2) // Allow some time for the server to stop
543}
544
545func setupPostgres(t testscript.T, cfg *config.Config) (error, func()) {
546 // Indicates postgres
547 // Create a disposable database
548 rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
549 dbName := fmt.Sprintf("softserve_test_%d", rnd.Int63())
550 dbDsn := cfg.DB.DataSource
551 if dbDsn == "" {
552 cfg.DB.DataSource = "postgres://postgres@localhost:5432/postgres?sslmode=disable"
553 }
554
555 dbUrl, err := url.Parse(cfg.DB.DataSource)
556 if err != nil {
557 return err, nil
558 }
559
560 scheme := dbUrl.Scheme
561 if scheme == "" {
562 scheme = "postgres"
563 }
564
565 host := dbUrl.Hostname()
566 if host == "" {
567 host = "localhost"
568 }
569
570 connInfo := fmt.Sprintf("host=%s sslmode=disable", host)
571 username := dbUrl.User.Username()
572 if username != "" {
573 connInfo += fmt.Sprintf(" user=%s", username)
574 password, ok := dbUrl.User.Password()
575 if ok {
576 username = fmt.Sprintf("%s:%s", username, password)
577 connInfo += fmt.Sprintf(" password=%s", password)
578 }
579 username = fmt.Sprintf("%s@", username)
580 } else {
581 connInfo += " user=postgres"
582 username = "postgres@"
583 }
584
585 port := dbUrl.Port()
586 if port != "" {
587 connInfo += fmt.Sprintf(" port=%s", port)
588 port = fmt.Sprintf(":%s", port)
589 }
590
591 cfg.DB.DataSource = fmt.Sprintf("%s://%s%s%s/%s?sslmode=disable",
592 scheme,
593 username,
594 host,
595 port,
596 dbName,
597 )
598
599 // Create the database
600 dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
601 if err != nil {
602 return err, nil
603 }
604
605 if _, err := dbx.Exec("CREATE DATABASE " + dbName); err != nil {
606 return err, nil
607 }
608
609 return nil, func() {
610 dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
611 if err != nil {
612 t.Fatal("failed to open database", dbName, err)
613 }
614
615 if _, err := dbx.Exec("DROP DATABASE " + dbName); err != nil {
616 t.Fatal("failed to drop database", dbName, err)
617 }
618 }
619}