1package testscript
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "flag"
8 "fmt"
9 "io"
10 "math/rand"
11 "net"
12 "net/http"
13 "net/url"
14 "os"
15 "os/exec"
16 "path/filepath"
17 "runtime"
18 "strconv"
19 "strings"
20 "testing"
21 "time"
22
23 "github.com/charmbracelet/keygen"
24 "github.com/charmbracelet/soft-serve/pkg/config"
25 "github.com/charmbracelet/soft-serve/pkg/db"
26 "github.com/charmbracelet/soft-serve/pkg/test"
27 "github.com/rogpeppe/go-internal/testscript"
28 "github.com/spf13/cobra"
29 "golang.org/x/crypto/ssh"
30)
31
// Package-level state shared by the tests in this file.
var (
	// binPath is the location of the soft binary built by TestMain.
	binPath string

	// update holds the value of the -update flag; TestScript forwards it to
	// testscript's UpdateScripts parameter.
	update = flag.Bool("update", false, "update script files")
)
36
37func TestMain(m *testing.M) {
38 tmp, err := os.MkdirTemp("", "soft-serve*")
39 if err != nil {
40 fmt.Fprintf(os.Stderr, "failed to create temporary directory: %s", err)
41 os.Exit(1)
42 }
43 defer os.RemoveAll(tmp)
44
45 binPath = filepath.Join(tmp, "soft")
46 if runtime.GOOS == "windows" {
47 binPath += ".exe"
48 }
49
50 // Build the soft binary with -cover flag.
51 cmd := exec.Command("go", "build", "-race", "-cover", "-o", binPath, filepath.Join("..", "cmd", "soft"))
52 if err := cmd.Run(); err != nil {
53 fmt.Fprintf(os.Stderr, "failed to build soft-serve binary: %s", err)
54 os.Exit(1)
55 }
56
57 // Run tests
58 os.Exit(m.Run())
59}
60
61func TestScript(t *testing.T) {
62 flag.Parse()
63
64 mkkey := func(name string) (string, *keygen.SSHKeyPair) {
65 path := filepath.Join(t.TempDir(), name)
66 pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
67 if err != nil {
68 t.Fatal(err)
69 }
70 return path, pair
71 }
72
73 admin1Key, admin1 := mkkey("admin1")
74 _, admin2 := mkkey("admin2")
75 user1Key, user1 := mkkey("user1")
76
77 testscript.Run(t, testscript.Params{
78 Dir: "./testdata/",
79 UpdateScripts: *update,
80 RequireExplicitExec: true,
81 Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
82 "soft": cmdSoft("admin", admin1.Signer()),
83 "usoft": cmdSoft("user1", user1.Signer()),
84 "git": cmdGit(admin1Key),
85 "ugit": cmdGit(user1Key),
86 "curl": cmdCurl,
87 "mkfile": cmdMkfile,
88 "envfile": cmdEnvfile,
89 "readfile": cmdReadfile,
90 "dos2unix": cmdDos2Unix,
91 "new-webhook": cmdNewWebhook,
92 "ensureserverrunning": cmdEnsureServerRunning,
93 "ensureservernotrunning": cmdEnsureServerNotRunning,
94 "stopserver": cmdStopserver,
95 "ui": cmdUI(admin1.Signer()),
96 "uui": cmdUI(user1.Signer()),
97 },
98 Setup: func(e *testscript.Env) error {
99 // Add binPath to PATH
100 e.Setenv("PATH", fmt.Sprintf("%s%c%s", filepath.Dir(binPath), os.PathListSeparator, e.Getenv("PATH")))
101
102 data := t.TempDir()
103 sshPort := test.RandomPort()
104 sshListen := fmt.Sprintf("localhost:%d", sshPort)
105 gitPort := test.RandomPort()
106 gitListen := fmt.Sprintf("localhost:%d", gitPort)
107 httpPort := test.RandomPort()
108 httpListen := fmt.Sprintf("localhost:%d", httpPort)
109 statsPort := test.RandomPort()
110 statsListen := fmt.Sprintf("localhost:%d", statsPort)
111 serverName := "Test Soft Serve"
112
113 e.Setenv("DATA_PATH", data)
114 e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
115 e.Setenv("HTTP_PORT", fmt.Sprintf("%d", httpPort))
116 e.Setenv("STATS_PORT", fmt.Sprintf("%d", statsPort))
117 e.Setenv("GIT_PORT", fmt.Sprintf("%d", gitPort))
118 e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
119 e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
120 e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
121 e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
122 e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
123
124 // This is used to set up test specific configuration and http endpoints
125 e.Setenv("SOFT_SERVE_TESTRUN", "1")
126
127 // This will disable the default lipgloss renderer colors
128 e.Setenv("SOFT_SERVE_NO_COLOR", "1")
129
130 // Soft Serve debug environment variables
131 for _, env := range []string{
132 "SOFT_SERVE_DEBUG",
133 "SOFT_SERVE_VERBOSE",
134 } {
135 if v, ok := os.LookupEnv(env); ok {
136 e.Setenv(env, v)
137 }
138 }
139
140 // TODO: test different configs
141 cfg := config.DefaultConfig()
142 cfg.DataPath = data
143 cfg.Name = serverName
144 cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
145 cfg.SSH.ListenAddr = sshListen
146 cfg.SSH.PublicURL = "ssh://" + sshListen
147 cfg.Git.ListenAddr = gitListen
148 cfg.HTTP.ListenAddr = httpListen
149 cfg.HTTP.PublicURL = "http://" + httpListen
150 cfg.Stats.ListenAddr = statsListen
151 cfg.LFS.Enabled = true
152
153 // Parse os SOFT_SERVE environment variables
154 if err := cfg.ParseEnv(); err != nil {
155 return err
156 }
157
158 // Override the database data source if we're using postgres
159 // so we can create a temporary database for the tests.
160 if cfg.DB.Driver == "postgres" {
161 err, cleanup := setupPostgres(e.T(), cfg)
162 if err != nil {
163 return err
164 }
165 if cleanup != nil {
166 e.Defer(cleanup)
167 }
168 }
169
170 for _, env := range cfg.Environ() {
171 parts := strings.SplitN(env, "=", 2)
172 if len(parts) != 2 {
173 e.T().Fatal("invalid environment variable", env)
174 }
175 e.Setenv(parts[0], parts[1])
176 }
177
178 return nil
179 },
180 })
181}
182
183func cmdSoft(user string, key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
184 return func(ts *testscript.TestScript, neg bool, args []string) {
185 cli, err := ssh.Dial(
186 "tcp",
187 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
188 &ssh.ClientConfig{
189 User: user,
190 Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
191 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
192 },
193 )
194 ts.Check(err)
195 defer cli.Close()
196
197 sess, err := cli.NewSession()
198 ts.Check(err)
199 defer sess.Close()
200
201 sess.Stdout = ts.Stdout()
202 sess.Stderr = ts.Stderr()
203
204 check(ts, sess.Run(strings.Join(args, " ")), neg)
205 }
206}
207
208func cmdUI(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
209 return func(ts *testscript.TestScript, neg bool, args []string) {
210 if len(args) < 1 {
211 ts.Fatalf("usage: ui <quoted string input>")
212 return
213 }
214
215 cli, err := ssh.Dial(
216 "tcp",
217 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
218 &ssh.ClientConfig{
219 User: "git",
220 Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
221 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
222 },
223 )
224 check(ts, err, neg)
225 defer cli.Close()
226
227 sess, err := cli.NewSession()
228 check(ts, err, neg)
229 defer sess.Close()
230
231 // XXX: this is a hack to make the UI tests work
232 // cmp command always complains about an extra newline
233 // in the output
234 defer ts.Stdout().Write([]byte("\n"))
235
236 sess.Stdout = ts.Stdout()
237 sess.Stderr = ts.Stderr()
238
239 stdin, err := sess.StdinPipe()
240 check(ts, err, neg)
241
242 err = sess.RequestPty("dumb", 40, 80, ssh.TerminalModes{})
243 check(ts, err, neg)
244 check(ts, sess.Start(""), neg)
245
246 in, err := strconv.Unquote(args[0])
247 check(ts, err, neg)
248 reader := strings.NewReader(in)
249 go func() {
250 defer stdin.Close()
251 for {
252 r, _, err := reader.ReadRune()
253 if err == io.EOF {
254 break
255 }
256 check(ts, err, neg)
257 stdin.Write([]byte(string(r))) // nolint: errcheck
258
259 // Wait for the UI to process the input
260 time.Sleep(100 * time.Millisecond)
261 }
262 }()
263
264 check(ts, sess.Wait(), neg)
265 }
266}
267
268// P.S. Windows sucks!
269func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
270 if neg {
271 ts.Fatalf("unsupported: ! dos2unix")
272 }
273 if len(args) < 1 {
274 ts.Fatalf("usage: dos2unix paths...")
275 }
276 for _, arg := range args {
277 filename := ts.MkAbs(arg)
278 data, err := os.ReadFile(filename)
279 if err != nil {
280 ts.Fatalf("%s: %v", filename, err)
281 }
282
283 // Replace all '\r\n' with '\n'.
284 data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
285
286 if err := os.WriteFile(filename, data, 0o644); err != nil {
287 ts.Fatalf("%s: %v", filename, err)
288 }
289 }
290}
291
// sshConfig is the ssh client configuration template that cmdGit writes to
// $SSH_KNOWN_CONFIG_FILE. Its single %q verb is filled with the value of
// $SSH_KNOWN_HOSTS_FILE.
var sshConfig = `
Host *
 UserKnownHostsFile %q
 StrictHostKeyChecking no
 IdentityAgent none
 IdentitiesOnly yes
 ServerAliveInterval 60
`
300
301func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
302 return func(ts *testscript.TestScript, neg bool, args []string) {
303 ts.Check(os.WriteFile(
304 ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
305 []byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
306 0o600,
307 ))
308 sshArgs := []string{
309 "-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
310 "-i", filepath.ToSlash(key),
311 }
312 ts.Setenv(
313 "GIT_SSH_COMMAND",
314 strings.Join(append([]string{"ssh"}, sshArgs...), " "),
315 )
316 // Disable git prompting for credentials.
317 ts.Setenv("GIT_TERMINAL_PROMPT", "0")
318 args = append([]string{
319 "-c", "user.email=john@example.com",
320 "-c", "user.name=John Doe",
321 }, args...)
322 check(ts, ts.Exec("git", args...), neg)
323 }
324}
325
326func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
327 if len(args) < 2 {
328 ts.Fatalf("usage: mkfile path content")
329 }
330 check(ts, os.WriteFile(
331 ts.MkAbs(args[0]),
332 []byte(strings.Join(args[1:], " ")),
333 0o644,
334 ), neg)
335}
336
337func check(ts *testscript.TestScript, err error, neg bool) {
338 if neg && err == nil {
339 ts.Fatalf("expected error, got nil")
340 }
341 if !neg {
342 ts.Check(err)
343 }
344}
345
346func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) {
347 ts.Stdout().Write([]byte(ts.ReadFile(args[0])))
348}
349
350func cmdEnvfile(ts *testscript.TestScript, neg bool, args []string) {
351 if len(args) < 1 {
352 ts.Fatalf("usage: envfile key=file...")
353 }
354
355 for _, arg := range args {
356 parts := strings.SplitN(arg, "=", 2)
357 if len(parts) != 2 {
358 ts.Fatalf("usage: envfile key=file...")
359 }
360 key := parts[0]
361 file := parts[1]
362 ts.Setenv(key, strings.TrimSpace(ts.ReadFile(file)))
363 }
364}
365
366func cmdNewWebhook(ts *testscript.TestScript, neg bool, args []string) {
367 type webhookSite struct {
368 UUID string `json:"uuid"`
369 }
370
371 if len(args) != 1 {
372 ts.Fatalf("usage: new-webhook <env-name>")
373 }
374
375 const whSite = "https://webhook.site"
376 req, err := http.NewRequest(http.MethodPost, whSite+"/token", nil)
377 check(ts, err, neg)
378
379 resp, err := http.DefaultClient.Do(req)
380 check(ts, err, neg)
381
382 defer resp.Body.Close()
383 var site webhookSite
384 check(ts, json.NewDecoder(resp.Body).Decode(&site), neg)
385
386 ts.Setenv(args[0], whSite+"/"+site.UUID)
387}
388
389func cmdCurl(ts *testscript.TestScript, neg bool, args []string) {
390 var verbose bool
391 var headers []string
392 var data string
393 method := http.MethodGet
394
395 cmd := &cobra.Command{
396 Use: "curl",
397 Args: cobra.MinimumNArgs(1),
398 RunE: func(cmd *cobra.Command, args []string) error {
399 url, err := url.Parse(args[0])
400 if err != nil {
401 return err
402 }
403
404 req, err := http.NewRequest(method, url.String(), nil)
405 if err != nil {
406 return err
407 }
408
409 if data != "" {
410 req.Body = io.NopCloser(strings.NewReader(data))
411 }
412
413 if verbose {
414 fmt.Fprintf(cmd.ErrOrStderr(), "< %s %s\n", req.Method, url.String())
415 }
416
417 for _, header := range headers {
418 parts := strings.SplitN(header, ":", 2)
419 if len(parts) != 2 {
420 return fmt.Errorf("invalid header: %s", header)
421 }
422 req.Header.Add(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
423 }
424
425 if userInfo := url.User; userInfo != nil {
426 password, _ := userInfo.Password()
427 req.SetBasicAuth(userInfo.Username(), password)
428 }
429
430 if verbose {
431 for key, values := range req.Header {
432 for _, value := range values {
433 fmt.Fprintf(cmd.ErrOrStderr(), "< %s: %s\n", key, value)
434 }
435 }
436 }
437
438 resp, err := http.DefaultClient.Do(req)
439 if err != nil {
440 return err
441 }
442
443 if verbose {
444 fmt.Fprintf(ts.Stderr(), "> %s\n", resp.Status)
445 for key, values := range resp.Header {
446 for _, value := range values {
447 fmt.Fprintf(cmd.ErrOrStderr(), "> %s: %s\n", key, value)
448 }
449 }
450 }
451
452 defer resp.Body.Close()
453 buf, err := io.ReadAll(resp.Body)
454 if err != nil {
455 return err
456 }
457
458 cmd.Print(string(buf))
459
460 return nil
461 },
462 }
463
464 cmd.SetArgs(args)
465 cmd.SetOut(ts.Stdout())
466 cmd.SetErr(ts.Stderr())
467
468 cmd.Flags().BoolVarP(&verbose, "verbose", "v", verbose, "verbose")
469 cmd.Flags().StringArrayVarP(&headers, "header", "H", nil, "HTTP header")
470 cmd.Flags().StringVarP(&method, "request", "X", method, "HTTP method")
471 cmd.Flags().StringVarP(&data, "data", "d", data, "HTTP data")
472
473 check(ts, cmd.Execute(), neg)
474}
475
476func cmdEnsureServerRunning(ts *testscript.TestScript, neg bool, args []string) {
477 if len(args) < 1 {
478 ts.Fatalf("Must supply a TCP port of one of the services to connect to. " +
479 "These are set as env vars as they are randomized. " +
480 "Example usage: \"cmdensureserverrunning SSH_PORT\"\n" +
481 "Valid values for the env var: SSH_PORT|HTTP_PORT|GIT_PORT|STATS_PORT")
482 }
483
484 port := ts.Getenv(args[0])
485
486 // verify that the server is up
487 addr := net.JoinHostPort("localhost", port)
488 for {
489 conn, _ := net.DialTimeout(
490 "tcp",
491 addr,
492 time.Second,
493 )
494 if conn != nil {
495 ts.Logf("Server is running on port: %s", port)
496 conn.Close()
497 break
498 }
499 }
500}
501
502func cmdEnsureServerNotRunning(ts *testscript.TestScript, neg bool, args []string) {
503 if len(args) < 1 {
504 ts.Fatalf("Must supply a TCP port of one of the services to connect to. " +
505 "These are set as env vars as they are randomized. " +
506 "Example usage: \"cmdensureservernotrunning SSH_PORT\"\n" +
507 "Valid values for the env var: SSH_PORT|HTTP_PORT|GIT_PORT|STATS_PORT")
508 }
509
510 port := ts.Getenv(args[0])
511
512 // verify that the server is not up
513 addr := net.JoinHostPort("localhost", port)
514 for {
515 conn, _ := net.DialTimeout(
516 "tcp",
517 addr,
518 time.Second,
519 )
520 if conn != nil {
521 ts.Fatalf("server is running on port %s while it should not be running", port)
522 conn.Close()
523 }
524 break
525 }
526}
527
528func cmdStopserver(ts *testscript.TestScript, neg bool, args []string) {
529 // stop the server
530 resp, err := http.DefaultClient.Head(fmt.Sprintf("%s/__stop", ts.Getenv("SOFT_SERVE_HTTP_PUBLIC_URL")))
531 check(ts, err, neg)
532 resp.Body.Close()
533 time.Sleep(time.Second * 2) // Allow some time for the server to stop
534}
535
536func setupPostgres(t testscript.T, cfg *config.Config) (error, func()) {
537 // Indicates postgres
538 // Create a disposable database
539 rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
540 dbName := fmt.Sprintf("softserve_test_%d", rnd.Int63())
541 dbDsn := cfg.DB.DataSource
542 if dbDsn == "" {
543 cfg.DB.DataSource = "postgres://postgres@localhost:5432/postgres?sslmode=disable"
544 }
545
546 dbUrl, err := url.Parse(cfg.DB.DataSource)
547 if err != nil {
548 return err, nil
549 }
550
551 scheme := dbUrl.Scheme
552 if scheme == "" {
553 scheme = "postgres"
554 }
555
556 host := dbUrl.Hostname()
557 if host == "" {
558 host = "localhost"
559 }
560
561 connInfo := fmt.Sprintf("host=%s sslmode=disable", host)
562 username := dbUrl.User.Username()
563 if username != "" {
564 connInfo += fmt.Sprintf(" user=%s", username)
565 password, ok := dbUrl.User.Password()
566 if ok {
567 username = fmt.Sprintf("%s:%s", username, password)
568 connInfo += fmt.Sprintf(" password=%s", password)
569 }
570 username = fmt.Sprintf("%s@", username)
571 } else {
572 connInfo += " user=postgres"
573 username = "postgres@"
574 }
575
576 port := dbUrl.Port()
577 if port != "" {
578 connInfo += fmt.Sprintf(" port=%s", port)
579 port = fmt.Sprintf(":%s", port)
580 }
581
582 cfg.DB.DataSource = fmt.Sprintf("%s://%s%s%s/%s?sslmode=disable",
583 scheme,
584 username,
585 host,
586 port,
587 dbName,
588 )
589
590 // Create the database
591 dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
592 if err != nil {
593 return err, nil
594 }
595
596 if _, err := dbx.Exec("CREATE DATABASE " + dbName); err != nil {
597 return err, nil
598 }
599
600 return nil, func() {
601 dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
602 if err != nil {
603 t.Fatal("failed to open database", dbName, err)
604 }
605
606 if _, err := dbx.Exec("DROP DATABASE " + dbName); err != nil {
607 t.Fatal("failed to drop database", dbName, err)
608 }
609 }
610}