1package testscript
2
import (
	"bytes"
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"math/rand"
	"net"
	"net/http"
	"net/url"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"sort"
	"strconv"
	"strings"
	"testing"
	"time"

	"github.com/charmbracelet/keygen"
	"github.com/charmbracelet/soft-serve/pkg/config"
	"github.com/charmbracelet/soft-serve/pkg/db"
	"github.com/charmbracelet/soft-serve/pkg/test"
	"github.com/rogpeppe/go-internal/testscript"
	"github.com/spf13/cobra"
	"golang.org/x/crypto/ssh"
)
31
// update is the -update flag; when set, testscript rewrites the golden
// sections of the script files instead of comparing against them.
var update = flag.Bool("update", false, "update script files")

// binPath is where TestMain places the freshly built soft binary.
var binPath string
36
// PrepareBuildCommand returns, without running it, a `go build` command that
// compiles ../cmd/soft with coverage instrumentation into binPath.
//
// The race detector is enabled unless SOFT_SERVE_DISABLE_RACE_CHECKS is
// present in the environment. Only its presence matters; any value,
// including the empty string, disables -race.
func PrepareBuildCommand(binPath string) *exec.Cmd {
	buildArgs := []string{"build"}
	if _, skipRace := os.LookupEnv("SOFT_SERVE_DISABLE_RACE_CHECKS"); !skipRace {
		buildArgs = append(buildArgs, "-race")
	}
	buildArgs = append(buildArgs, "-cover", "-o", binPath, filepath.Join("..", "cmd", "soft"))
	return exec.Command("go", buildArgs...) //nolint:noctx
}
45
46func TestMain(m *testing.M) {
47 tmp, err := os.MkdirTemp("", "soft-serve*")
48 if err != nil {
49 fmt.Fprintf(os.Stderr, "failed to create temporary directory: %s", err)
50 os.Exit(1)
51 }
52 defer os.RemoveAll(tmp)
53
54 binPath = filepath.Join(tmp, "soft")
55 if runtime.GOOS == "windows" {
56 binPath += ".exe"
57 }
58
59 // Build the soft binary with -cover flag.
60 cmd := PrepareBuildCommand(binPath)
61 if err := cmd.Run(); err != nil {
62 fmt.Fprintf(os.Stderr, "failed to build soft-serve binary: %s", err)
63 os.Exit(1)
64 }
65
66 // Run tests
67 os.Exit(m.Run())
68}
69
70func TestScript(t *testing.T) {
71 flag.Parse()
72
73 mkkey := func(name string) (string, *keygen.SSHKeyPair) {
74 path := filepath.Join(t.TempDir(), name)
75 pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
76 if err != nil {
77 t.Fatal(err)
78 }
79 return path, pair
80 }
81
82 admin1Key, admin1 := mkkey("admin1")
83 _, admin2 := mkkey("admin2")
84 user1Key, user1 := mkkey("user1")
85 attackerKey, attacker := mkkey("attacker")
86 attackerSigner := &maliciousSigner{
87 publicKey: admin1.PublicKey(),
88 }
89
90 testscript.Run(t, testscript.Params{
91 Dir: "./testdata/",
92 UpdateScripts: *update,
93 RequireExplicitExec: true,
94 Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
95 "soft": cmdSoft("admin", admin1.Signer()),
96 "usoft": cmdSoft("user1", user1.Signer()),
97 "attacksoft": cmdSoft("attacker", attackerSigner, attacker.Signer()),
98 "git": cmdGit(admin1Key),
99 "ugit": cmdGit(user1Key),
100 "agit": cmdGit(attackerKey),
101 "curl": cmdCurl,
102 "mkfile": cmdMkfile,
103 "envfile": cmdEnvfile,
104 "readfile": cmdReadfile,
105 "dos2unix": cmdDos2Unix,
106 "new-webhook": cmdNewWebhook,
107 "ensureserverrunning": cmdEnsureServerRunning,
108 "ensureservernotrunning": cmdEnsureServerNotRunning,
109 "stopserver": cmdStopserver,
110 "ui": cmdUI(admin1.Signer()),
111 "uui": cmdUI(user1.Signer()),
112 },
113 Setup: func(e *testscript.Env) error {
114 // Add binPath to PATH
115 e.Setenv("PATH", fmt.Sprintf("%s%c%s", filepath.Dir(binPath), os.PathListSeparator, e.Getenv("PATH")))
116
117 data := t.TempDir()
118 sshPort := test.RandomPort()
119 sshListen := fmt.Sprintf("localhost:%d", sshPort)
120 gitPort := test.RandomPort()
121 gitListen := fmt.Sprintf("localhost:%d", gitPort)
122 httpPort := test.RandomPort()
123 httpListen := fmt.Sprintf("localhost:%d", httpPort)
124 statsPort := test.RandomPort()
125 statsListen := fmt.Sprintf("localhost:%d", statsPort)
126 serverName := "Test Soft Serve"
127
128 e.Setenv("DATA_PATH", data)
129 e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
130 e.Setenv("HTTP_PORT", fmt.Sprintf("%d", httpPort))
131 e.Setenv("STATS_PORT", fmt.Sprintf("%d", statsPort))
132 e.Setenv("GIT_PORT", fmt.Sprintf("%d", gitPort))
133 e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
134 e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
135 e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
136 e.Setenv("ATTACKER_AUTHORIZED_KEY", attacker.AuthorizedKey())
137 e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
138 e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
139
140 // This is used to set up test specific configuration and http endpoints
141 e.Setenv("SOFT_SERVE_TESTRUN", "1")
142
143 // This will disable the default lipgloss renderer colors
144 e.Setenv("SOFT_SERVE_NO_COLOR", "1")
145
146 // Soft Serve debug environment variables
147 for _, env := range []string{
148 "SOFT_SERVE_DEBUG",
149 "SOFT_SERVE_VERBOSE",
150 } {
151 if v, ok := os.LookupEnv(env); ok {
152 e.Setenv(env, v)
153 }
154 }
155
156 // TODO: test different configs
157 cfg := config.DefaultConfig()
158 cfg.DataPath = data
159 cfg.Name = serverName
160 cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
161 cfg.SSH.ListenAddr = sshListen
162 cfg.SSH.PublicURL = "ssh://" + sshListen
163 cfg.Git.ListenAddr = gitListen
164 cfg.HTTP.ListenAddr = httpListen
165 cfg.HTTP.PublicURL = "http://" + httpListen
166 cfg.Stats.ListenAddr = statsListen
167 cfg.LFS.Enabled = true
168
169 // Parse os SOFT_SERVE environment variables
170 if err := cfg.ParseEnv(); err != nil {
171 return err
172 }
173
174 // Override the database data source if we're using postgres
175 // so we can create a temporary database for the tests.
176 if cfg.DB.Driver == "postgres" {
177 cleanup, err := setupPostgres(e.T(), cfg)
178 if err != nil {
179 return err
180 }
181 if cleanup != nil {
182 e.Defer(cleanup)
183 }
184 }
185
186 for _, env := range cfg.Environ() {
187 parts := strings.SplitN(env, "=", 2)
188 if len(parts) != 2 {
189 e.T().Fatal("invalid environment variable", env)
190 }
191 e.Setenv(parts[0], parts[1])
192 }
193
194 return nil
195 },
196 })
197}
198
199func cmdSoft(user string, keys ...ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
200 return func(ts *testscript.TestScript, neg bool, args []string) {
201 cli, err := ssh.Dial(
202 "tcp",
203 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
204 &ssh.ClientConfig{
205 User: user,
206 Auth: []ssh.AuthMethod{ssh.PublicKeys(keys...)},
207 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
208 },
209 )
210 ts.Check(err)
211 defer cli.Close()
212
213 sess, err := cli.NewSession()
214 ts.Check(err)
215 defer sess.Close()
216
217 sess.Stdout = ts.Stdout()
218 sess.Stderr = ts.Stderr()
219
220 check(ts, sess.Run(strings.Join(args, " ")), neg)
221 }
222}
223
224func cmdUI(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
225 return func(ts *testscript.TestScript, neg bool, args []string) {
226 if len(args) < 1 {
227 ts.Fatalf("usage: ui <quoted string input>")
228 return
229 }
230
231 cli, err := ssh.Dial(
232 "tcp",
233 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
234 &ssh.ClientConfig{
235 User: "git",
236 Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
237 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
238 },
239 )
240 check(ts, err, neg)
241 defer cli.Close()
242
243 sess, err := cli.NewSession()
244 check(ts, err, neg)
245 defer sess.Close()
246
247 // XXX: this is a hack to make the UI tests work
248 // cmp command always complains about an extra newline
249 // in the output
250 defer ts.Stdout().Write([]byte("\n"))
251
252 sess.Stdout = ts.Stdout()
253 sess.Stderr = ts.Stderr()
254
255 stdin, err := sess.StdinPipe()
256 check(ts, err, neg)
257
258 err = sess.RequestPty("dumb", 40, 80, ssh.TerminalModes{})
259 check(ts, err, neg)
260 check(ts, sess.Start(""), neg)
261
262 in, err := strconv.Unquote(args[0])
263 check(ts, err, neg)
264 reader := strings.NewReader(in)
265 go func() {
266 defer stdin.Close()
267 for {
268 r, _, err := reader.ReadRune()
269 if err == io.EOF {
270 break
271 }
272 check(ts, err, neg)
273 _, _ = io.WriteString(stdin, string(r))
274
275 // Wait for the UI to process the input
276 time.Sleep(100 * time.Millisecond)
277 }
278 }()
279
280 check(ts, sess.Wait(), neg)
281 }
282}
283
284func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
285 if neg {
286 ts.Fatalf("unsupported: ! dos2unix")
287 }
288 if len(args) < 1 {
289 ts.Fatalf("usage: dos2unix paths...")
290 }
291 for _, arg := range args {
292 filename := ts.MkAbs(arg)
293 data, err := os.ReadFile(filename)
294 if err != nil {
295 ts.Fatalf("%s: %v", filename, err)
296 }
297
298 // Replace all '\r\n' with '\n'.
299 data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
300
301 if err := os.WriteFile(filename, data, 0o644); err != nil {
302 ts.Fatalf("%s: %v", filename, err)
303 }
304 }
305}
306
// sshConfig is the ssh_config template that cmdGit writes before each git
// invocation. Its single %q verb receives the known-hosts file path.
//
// The settings disable strict host-key checking and bypass the SSH agent.
// With IdentitiesOnly, only identities named explicitly (cmdGit passes -i)
// are offered.
var sshConfig = `
Host *
  UserKnownHostsFile %q
  StrictHostKeyChecking no
  IdentityAgent none
  IdentitiesOnly yes
  ServerAliveInterval 60
`
315
316func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
317 return func(ts *testscript.TestScript, neg bool, args []string) {
318 ts.Check(os.WriteFile(
319 ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
320 []byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
321 0o600,
322 ))
323 sshArgs := []string{
324 "-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
325 "-i", filepath.ToSlash(key),
326 }
327 ts.Setenv(
328 "GIT_SSH_COMMAND",
329 strings.Join(append([]string{"ssh"}, sshArgs...), " "),
330 )
331 // Disable git prompting for credentials.
332 ts.Setenv("GIT_TERMINAL_PROMPT", "0")
333 args = append([]string{
334 "-c", "user.email=john@example.com",
335 "-c", "user.name=John Doe",
336 }, args...)
337 check(ts, ts.Exec("git", args...), neg)
338 }
339}
340
341func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
342 if len(args) < 2 {
343 ts.Fatalf("usage: mkfile path content")
344 }
345 check(ts, os.WriteFile(
346 ts.MkAbs(args[0]),
347 []byte(strings.Join(args[1:], " ")),
348 0o644,
349 ), neg)
350}
351
352func check(ts *testscript.TestScript, err error, neg bool) {
353 if neg && err == nil {
354 ts.Fatalf("expected error, got nil")
355 }
356 if !neg {
357 ts.Check(err)
358 }
359}
360
361func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) {
362 ts.Stdout().Write([]byte(ts.ReadFile(args[0])))
363}
364
365func cmdEnvfile(ts *testscript.TestScript, neg bool, args []string) {
366 if len(args) < 1 {
367 ts.Fatalf("usage: envfile key=file...")
368 }
369
370 for _, arg := range args {
371 parts := strings.SplitN(arg, "=", 2)
372 if len(parts) != 2 {
373 ts.Fatalf("usage: envfile key=file...")
374 }
375 key := parts[0]
376 file := parts[1]
377 ts.Setenv(key, strings.TrimSpace(ts.ReadFile(file)))
378 }
379}
380
381func cmdNewWebhook(ts *testscript.TestScript, neg bool, args []string) {
382 type webhookSite struct {
383 UUID string `json:"uuid"`
384 }
385
386 if len(args) != 1 {
387 ts.Fatalf("usage: new-webhook <env-name>")
388 }
389
390 const whSite = "https://webhook.site"
391 req, err := http.NewRequest(http.MethodPost, whSite+"/token", nil) //nolint:noctx
392 check(ts, err, neg)
393
394 resp, err := http.DefaultClient.Do(req)
395 check(ts, err, neg)
396
397 defer resp.Body.Close()
398 var site webhookSite
399 check(ts, json.NewDecoder(resp.Body).Decode(&site), neg)
400
401 ts.Setenv(args[0], whSite+"/"+site.UUID)
402}
403
404func cmdCurl(ts *testscript.TestScript, neg bool, args []string) {
405 var verbose bool
406 var headers []string
407 var data string
408 method := http.MethodGet
409
410 cmd := &cobra.Command{
411 Use: "curl",
412 Args: cobra.MinimumNArgs(1),
413 RunE: func(cmd *cobra.Command, args []string) error {
414 url, err := url.Parse(args[0])
415 if err != nil {
416 return err
417 }
418
419 req, err := http.NewRequest(method, url.String(), nil) //nolint:noctx
420 if err != nil {
421 return err
422 }
423
424 if data != "" {
425 req.Body = io.NopCloser(strings.NewReader(data))
426 }
427
428 if verbose {
429 fmt.Fprintf(cmd.ErrOrStderr(), "< %s %s\n", req.Method, url.String())
430 }
431
432 for _, header := range headers {
433 parts := strings.SplitN(header, ":", 2)
434 if len(parts) != 2 {
435 return fmt.Errorf("invalid header: %s", header)
436 }
437 req.Header.Add(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
438 }
439
440 if userInfo := url.User; userInfo != nil {
441 password, _ := userInfo.Password()
442 req.SetBasicAuth(userInfo.Username(), password)
443 }
444
445 if verbose {
446 for key, values := range req.Header {
447 for _, value := range values {
448 fmt.Fprintf(cmd.ErrOrStderr(), "< %s: %s\n", key, value)
449 }
450 }
451 }
452
453 resp, err := http.DefaultClient.Do(req)
454 if err != nil {
455 return err
456 }
457
458 if verbose {
459 fmt.Fprintf(ts.Stderr(), "> %s\n", resp.Status)
460 for key, values := range resp.Header {
461 for _, value := range values {
462 fmt.Fprintf(cmd.ErrOrStderr(), "> %s: %s\n", key, value)
463 }
464 }
465 }
466
467 defer resp.Body.Close()
468 buf, err := io.ReadAll(resp.Body)
469 if err != nil {
470 return err
471 }
472
473 cmd.Print(string(buf))
474
475 return nil
476 },
477 }
478
479 cmd.SetArgs(args)
480 cmd.SetOut(ts.Stdout())
481 cmd.SetErr(ts.Stderr())
482
483 cmd.Flags().BoolVarP(&verbose, "verbose", "v", verbose, "verbose")
484 cmd.Flags().StringArrayVarP(&headers, "header", "H", nil, "HTTP header")
485 cmd.Flags().StringVarP(&method, "request", "X", method, "HTTP method")
486 cmd.Flags().StringVarP(&data, "data", "d", data, "HTTP data")
487
488 check(ts, cmd.Execute(), neg)
489}
490
491func cmdEnsureServerRunning(ts *testscript.TestScript, neg bool, args []string) {
492 if len(args) < 1 {
493 ts.Fatalf("Must supply a TCP port of one of the services to connect to. " +
494 "These are set as env vars as they are randomized. " +
495 "Example usage: \"cmdensureserverrunning SSH_PORT\"\n" +
496 "Valid values for the env var: SSH_PORT|HTTP_PORT|GIT_PORT|STATS_PORT")
497 }
498
499 port := ts.Getenv(args[0])
500
501 // verify that the server is up
502 addr := net.JoinHostPort("localhost", port)
503 for {
504 conn, _ := net.DialTimeout( //nolint:noctx
505 "tcp",
506 addr,
507 time.Second,
508 )
509 if conn != nil {
510 ts.Logf("Server is running on port: %s", port)
511 conn.Close()
512 break
513 }
514 }
515}
516
517func cmdEnsureServerNotRunning(ts *testscript.TestScript, neg bool, args []string) {
518 if len(args) < 1 {
519 ts.Fatalf("Must supply a TCP port of one of the services to connect to. " +
520 "These are set as env vars as they are randomized. " +
521 "Example usage: \"cmdensureservernotrunning SSH_PORT\"\n" +
522 "Valid values for the env var: SSH_PORT|HTTP_PORT|GIT_PORT|STATS_PORT")
523 }
524
525 port := ts.Getenv(args[0])
526
527 // verify that the server is not up
528 addr := net.JoinHostPort("localhost", port)
529 conn, _ := net.DialTimeout( //nolint:noctx
530 "tcp",
531 addr,
532 time.Second,
533 )
534 if conn != nil {
535 ts.Fatalf("server is running on port %s while it should not be running", port)
536 conn.Close()
537 }
538}
539
540func cmdStopserver(ts *testscript.TestScript, neg bool, args []string) {
541 // stop the server
542 resp, err := http.DefaultClient.Head(fmt.Sprintf("%s/__stop", ts.Getenv("SOFT_SERVE_HTTP_PUBLIC_URL"))) //nolint:noctx
543 check(ts, err, neg)
544 resp.Body.Close()
545 time.Sleep(time.Second * 2) // Allow some time for the server to stop
546}
547
548func setupPostgres(t testscript.T, cfg *config.Config) (func(), error) {
549 // Indicates postgres
550 // Create a disposable database
551 rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
552 dbName := fmt.Sprintf("softserve_test_%d", rnd.Int63())
553 dbDsn := cfg.DB.DataSource
554 if dbDsn == "" {
555 cfg.DB.DataSource = "postgres://postgres@localhost:5432/postgres?sslmode=disable"
556 }
557
558 dbUrl, err := url.Parse(cfg.DB.DataSource)
559 if err != nil {
560 return nil, err
561 }
562
563 scheme := dbUrl.Scheme
564 if scheme == "" {
565 scheme = "postgres"
566 }
567
568 host := dbUrl.Hostname()
569 if host == "" {
570 host = "localhost"
571 }
572
573 connInfo := fmt.Sprintf("host=%s sslmode=disable", host)
574 username := dbUrl.User.Username()
575 if username != "" {
576 connInfo += fmt.Sprintf(" user=%s", username)
577 password, ok := dbUrl.User.Password()
578 if ok {
579 username = fmt.Sprintf("%s:%s", username, password)
580 connInfo += fmt.Sprintf(" password=%s", password)
581 }
582 username = fmt.Sprintf("%s@", username)
583 } else {
584 connInfo += " user=postgres"
585 username = "postgres@"
586 }
587
588 port := dbUrl.Port()
589 if port != "" {
590 connInfo += fmt.Sprintf(" port=%s", port)
591 port = fmt.Sprintf(":%s", port)
592 }
593
594 cfg.DB.DataSource = fmt.Sprintf("%s://%s%s%s/%s?sslmode=disable",
595 scheme,
596 username,
597 host,
598 port,
599 dbName,
600 )
601
602 // Create the database
603 dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
604 if err != nil {
605 return nil, err
606 }
607
608 if _, err := dbx.ExecContext(context.TODO(), "CREATE DATABASE "+dbName); err != nil {
609 return nil, err
610 }
611
612 return func() {
613 dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
614 if err != nil {
615 t.Fatal("failed to open database", dbName, err)
616 }
617
618 if _, err := dbx.ExecContext(context.TODO(), "DROP DATABASE "+dbName); err != nil {
619 t.Fatal("failed to drop database", dbName, err)
620 }
621 }, nil
622}
623
624type maliciousSigner struct {
625 publicKey ssh.PublicKey
626}
627
628var _ ssh.Signer = (*maliciousSigner)(nil)
629
630// PublicKey implements ssh.Signer.
631func (m *maliciousSigner) PublicKey() ssh.PublicKey {
632 return m.publicKey
633}
634
635// Sign implements ssh.Signer.
636func (m *maliciousSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
637 // The attacker doesn't know how to sign the data without a private key.
638 return &ssh.Signature{}, nil
639}