1package testscript
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "flag"
8 "fmt"
9 "io"
10 "math/rand"
11 "net"
12 "net/http"
13 "net/url"
14 "os"
15 "os/exec"
16 "path/filepath"
17 "runtime"
18 "strconv"
19 "strings"
20 "testing"
21 "time"
22
23 "github.com/charmbracelet/keygen"
24 "github.com/charmbracelet/soft-serve/pkg/config"
25 "github.com/charmbracelet/soft-serve/pkg/db"
26 "github.com/charmbracelet/soft-serve/pkg/test"
27 "github.com/rogpeppe/go-internal/testscript"
28 "github.com/spf13/cobra"
29 "golang.org/x/crypto/ssh"
30)
31
// update is the -update test flag. It is handed to
// testscript.Params.UpdateScripts so that a run can rewrite the script files
// instead of failing on mismatches.
var update = flag.Bool("update", false, "update script files")

// binPath is the absolute path of the soft binary that TestMain builds into a
// temporary directory; TestScript prepends its directory to each script's PATH.
var binPath string
36
37func TestMain(m *testing.M) {
38 tmp, err := os.MkdirTemp("", "soft-serve*")
39 if err != nil {
40 fmt.Fprintf(os.Stderr, "failed to create temporary directory: %s", err)
41 os.Exit(1)
42 }
43 defer os.RemoveAll(tmp)
44
45 binPath = filepath.Join(tmp, "soft")
46 if runtime.GOOS == "windows" {
47 binPath += ".exe"
48 }
49
50 // Build the soft binary with -cover flag.
51 cmd := exec.Command("go", "build", "-race", "-cover", "-o", binPath, filepath.Join("..", "cmd", "soft"))
52 if err := cmd.Run(); err != nil {
53 fmt.Fprintf(os.Stderr, "failed to build soft-serve binary: %s", err)
54 os.Exit(1)
55 }
56
57 // Run tests
58 os.Exit(m.Run())
59}
60
61func TestScript(t *testing.T) {
62 flag.Parse()
63
64 mkkey := func(name string) (string, *keygen.SSHKeyPair) {
65 path := filepath.Join(t.TempDir(), name)
66 pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
67 if err != nil {
68 t.Fatal(err)
69 }
70 return path, pair
71 }
72
73 admin1Key, admin1 := mkkey("admin1")
74 _, admin2 := mkkey("admin2")
75 user1Key, user1 := mkkey("user1")
76
77 testscript.Run(t, testscript.Params{
78 Dir: "./testdata/",
79 UpdateScripts: *update,
80 RequireExplicitExec: true,
81 Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
82 "soft": cmdSoft("admin", admin1.Signer()),
83 "usoft": cmdSoft("user1", user1.Signer()),
84 "git": cmdGit(admin1Key),
85 "ugit": cmdGit(user1Key),
86 "curl": cmdCurl,
87 "mkfile": cmdMkfile,
88 "envfile": cmdEnvfile,
89 "readfile": cmdReadfile,
90 "dos2unix": cmdDos2Unix,
91 "new-webhook": cmdNewWebhook,
92 "ensureserverrunning": cmdEnsureServerRunning,
93 "stopserver": cmdStopserver,
94 "ui": cmdUI(admin1.Signer()),
95 "uui": cmdUI(user1.Signer()),
96 },
97 Setup: func(e *testscript.Env) error {
98 // Add binPath to PATH
99 e.Setenv("PATH", fmt.Sprintf("%s%c%s", filepath.Dir(binPath), os.PathListSeparator, e.Getenv("PATH")))
100
101 data := t.TempDir()
102 sshPort := test.RandomPort()
103 sshListen := fmt.Sprintf("localhost:%d", sshPort)
104 gitPort := test.RandomPort()
105 gitListen := fmt.Sprintf("localhost:%d", gitPort)
106 httpPort := test.RandomPort()
107 httpListen := fmt.Sprintf("localhost:%d", httpPort)
108 statsPort := test.RandomPort()
109 statsListen := fmt.Sprintf("localhost:%d", statsPort)
110 serverName := "Test Soft Serve"
111
112 e.Setenv("DATA_PATH", data)
113 e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
114 e.Setenv("HTTP_PORT", fmt.Sprintf("%d", httpPort))
115 e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
116 e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
117 e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
118 e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
119 e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
120
121 // This is used to set up test specific configuration and http endpoints
122 e.Setenv("SOFT_SERVE_TESTRUN", "1")
123
124 // This will disable the default lipgloss renderer colors
125 e.Setenv("SOFT_SERVE_NO_COLOR", "1")
126
127 // Soft Serve debug environment variables
128 for _, env := range []string{
129 "SOFT_SERVE_DEBUG",
130 "SOFT_SERVE_VERBOSE",
131 } {
132 if v, ok := os.LookupEnv(env); ok {
133 e.Setenv(env, v)
134 }
135 }
136
137 // TODO: test different configs
138 cfg := config.DefaultConfig()
139 cfg.DataPath = data
140 cfg.Name = serverName
141 cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
142 cfg.SSH.ListenAddr = sshListen
143 cfg.SSH.PublicURL = "ssh://" + sshListen
144 cfg.Git.ListenAddr = gitListen
145 cfg.HTTP.ListenAddr = httpListen
146 cfg.HTTP.PublicURL = "http://" + httpListen
147 cfg.Stats.ListenAddr = statsListen
148 cfg.LFS.Enabled = true
149
150 // Parse os SOFT_SERVE environment variables
151 if err := cfg.ParseEnv(); err != nil {
152 return err
153 }
154
155 // Override the database data source if we're using postgres
156 // so we can create a temporary database for the tests.
157 if cfg.DB.Driver == "postgres" {
158 err, cleanup := setupPostgres(e.T(), cfg)
159 if err != nil {
160 return err
161 }
162 if cleanup != nil {
163 e.Defer(cleanup)
164 }
165 }
166
167 for _, env := range cfg.Environ() {
168 parts := strings.SplitN(env, "=", 2)
169 if len(parts) != 2 {
170 e.T().Fatal("invalid environment variable", env)
171 }
172 e.Setenv(parts[0], parts[1])
173 }
174
175 return nil
176 },
177 })
178}
179
180func cmdSoft(user string, key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
181 return func(ts *testscript.TestScript, neg bool, args []string) {
182 cli, err := ssh.Dial(
183 "tcp",
184 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
185 &ssh.ClientConfig{
186 User: user,
187 Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
188 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
189 },
190 )
191 ts.Check(err)
192 defer cli.Close()
193
194 sess, err := cli.NewSession()
195 ts.Check(err)
196 defer sess.Close()
197
198 sess.Stdout = ts.Stdout()
199 sess.Stderr = ts.Stderr()
200
201 check(ts, sess.Run(strings.Join(args, " ")), neg)
202 }
203}
204
205func cmdUI(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
206 return func(ts *testscript.TestScript, neg bool, args []string) {
207 if len(args) < 1 {
208 ts.Fatalf("usage: ui <quoted string input>")
209 return
210 }
211
212 cli, err := ssh.Dial(
213 "tcp",
214 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
215 &ssh.ClientConfig{
216 User: "git",
217 Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
218 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
219 },
220 )
221 check(ts, err, neg)
222 defer cli.Close()
223
224 sess, err := cli.NewSession()
225 check(ts, err, neg)
226 defer sess.Close()
227
228 // XXX: this is a hack to make the UI tests work
229 // cmp command always complains about an extra newline
230 // in the output
231 defer ts.Stdout().Write([]byte("\n"))
232
233 sess.Stdout = ts.Stdout()
234 sess.Stderr = ts.Stderr()
235
236 stdin, err := sess.StdinPipe()
237 check(ts, err, neg)
238
239 err = sess.RequestPty("dumb", 40, 80, ssh.TerminalModes{})
240 check(ts, err, neg)
241 check(ts, sess.Start(""), neg)
242
243 in, err := strconv.Unquote(args[0])
244 check(ts, err, neg)
245 reader := strings.NewReader(in)
246 go func() {
247 defer stdin.Close()
248 for {
249 r, _, err := reader.ReadRune()
250 if err == io.EOF {
251 break
252 }
253 check(ts, err, neg)
254 stdin.Write([]byte(string(r))) // nolint: errcheck
255
256 // Wait for the UI to process the input
257 time.Sleep(100 * time.Millisecond)
258 }
259 }()
260
261 check(ts, sess.Wait(), neg)
262 }
263}
264
265// P.S. Windows sucks!
266func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
267 if neg {
268 ts.Fatalf("unsupported: ! dos2unix")
269 }
270 if len(args) < 1 {
271 ts.Fatalf("usage: dos2unix paths...")
272 }
273 for _, arg := range args {
274 filename := ts.MkAbs(arg)
275 data, err := os.ReadFile(filename)
276 if err != nil {
277 ts.Fatalf("%s: %v", filename, err)
278 }
279
280 // Replace all '\r\n' with '\n'.
281 data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
282
283 if err := os.WriteFile(filename, data, 0o644); err != nil {
284 ts.Fatalf("%s: %v", filename, err)
285 }
286 }
287}
288
// sshConfig is an ssh_config(5) template that cmdGit renders with
// fmt.Sprintf and writes to $SSH_KNOWN_CONFIG_FILE before every git call.
// Its single %q verb receives the path in $SSH_KNOWN_HOSTS_FILE.
// The directives:
//   - turn off strict host key checking, because the test server's key is
//     new for every run;
//   - stop ssh from consulting an agent (IdentityAgent none);
//   - restrict it to explicitly configured identities (IdentitiesOnly yes).
// As a result, only the key cmdGit passes with -i is offered.
var sshConfig = `
Host *
  UserKnownHostsFile %q
  StrictHostKeyChecking no
  IdentityAgent none
  IdentitiesOnly yes
  ServerAliveInterval 60
`
297
298func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
299 return func(ts *testscript.TestScript, neg bool, args []string) {
300 ts.Check(os.WriteFile(
301 ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
302 []byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
303 0o600,
304 ))
305 sshArgs := []string{
306 "-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
307 "-i", filepath.ToSlash(key),
308 }
309 ts.Setenv(
310 "GIT_SSH_COMMAND",
311 strings.Join(append([]string{"ssh"}, sshArgs...), " "),
312 )
313 // Disable git prompting for credentials.
314 ts.Setenv("GIT_TERMINAL_PROMPT", "0")
315 args = append([]string{
316 "-c", "user.email=john@example.com",
317 "-c", "user.name=John Doe",
318 }, args...)
319 check(ts, ts.Exec("git", args...), neg)
320 }
321}
322
323func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
324 if len(args) < 2 {
325 ts.Fatalf("usage: mkfile path content")
326 }
327 check(ts, os.WriteFile(
328 ts.MkAbs(args[0]),
329 []byte(strings.Join(args[1:], " ")),
330 0o644,
331 ), neg)
332}
333
334func check(ts *testscript.TestScript, err error, neg bool) {
335 if neg && err == nil {
336 ts.Fatalf("expected error, got nil")
337 }
338 if !neg {
339 ts.Check(err)
340 }
341}
342
343func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) {
344 ts.Stdout().Write([]byte(ts.ReadFile(args[0])))
345}
346
347func cmdEnvfile(ts *testscript.TestScript, neg bool, args []string) {
348 if len(args) < 1 {
349 ts.Fatalf("usage: envfile key=file...")
350 }
351
352 for _, arg := range args {
353 parts := strings.SplitN(arg, "=", 2)
354 if len(parts) != 2 {
355 ts.Fatalf("usage: envfile key=file...")
356 }
357 key := parts[0]
358 file := parts[1]
359 ts.Setenv(key, strings.TrimSpace(ts.ReadFile(file)))
360 }
361}
362
363func cmdNewWebhook(ts *testscript.TestScript, neg bool, args []string) {
364 type webhookSite struct {
365 UUID string `json:"uuid"`
366 }
367
368 if len(args) != 1 {
369 ts.Fatalf("usage: new-webhook <env-name>")
370 }
371
372 const whSite = "https://webhook.site"
373 req, err := http.NewRequest(http.MethodPost, whSite+"/token", nil)
374 check(ts, err, neg)
375
376 resp, err := http.DefaultClient.Do(req)
377 check(ts, err, neg)
378
379 defer resp.Body.Close()
380 var site webhookSite
381 check(ts, json.NewDecoder(resp.Body).Decode(&site), neg)
382
383 ts.Setenv(args[0], whSite+"/"+site.UUID)
384}
385
386func cmdCurl(ts *testscript.TestScript, neg bool, args []string) {
387 var verbose bool
388 var headers []string
389 var data string
390 method := http.MethodGet
391
392 cmd := &cobra.Command{
393 Use: "curl",
394 Args: cobra.MinimumNArgs(1),
395 RunE: func(cmd *cobra.Command, args []string) error {
396 url, err := url.Parse(args[0])
397 if err != nil {
398 return err
399 }
400
401 req, err := http.NewRequest(method, url.String(), nil)
402 if err != nil {
403 return err
404 }
405
406 if data != "" {
407 req.Body = io.NopCloser(strings.NewReader(data))
408 }
409
410 if verbose {
411 fmt.Fprintf(cmd.ErrOrStderr(), "< %s %s\n", req.Method, url.String())
412 }
413
414 for _, header := range headers {
415 parts := strings.SplitN(header, ":", 2)
416 if len(parts) != 2 {
417 return fmt.Errorf("invalid header: %s", header)
418 }
419 req.Header.Add(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
420 }
421
422 if userInfo := url.User; userInfo != nil {
423 password, _ := userInfo.Password()
424 req.SetBasicAuth(userInfo.Username(), password)
425 }
426
427 if verbose {
428 for key, values := range req.Header {
429 for _, value := range values {
430 fmt.Fprintf(cmd.ErrOrStderr(), "< %s: %s\n", key, value)
431 }
432 }
433 }
434
435 resp, err := http.DefaultClient.Do(req)
436 if err != nil {
437 return err
438 }
439
440 if verbose {
441 fmt.Fprintf(ts.Stderr(), "> %s\n", resp.Status)
442 for key, values := range resp.Header {
443 for _, value := range values {
444 fmt.Fprintf(cmd.ErrOrStderr(), "> %s: %s\n", key, value)
445 }
446 }
447 }
448
449 defer resp.Body.Close()
450 buf, err := io.ReadAll(resp.Body)
451 if err != nil {
452 return err
453 }
454
455 cmd.Print(string(buf))
456
457 return nil
458 },
459 }
460
461 cmd.SetArgs(args)
462 cmd.SetOut(ts.Stdout())
463 cmd.SetErr(ts.Stderr())
464
465 cmd.Flags().BoolVarP(&verbose, "verbose", "v", verbose, "verbose")
466 cmd.Flags().StringArrayVarP(&headers, "header", "H", nil, "HTTP header")
467 cmd.Flags().StringVarP(&method, "request", "X", method, "HTTP method")
468 cmd.Flags().StringVarP(&data, "data", "d", data, "HTTP data")
469
470 check(ts, cmd.Execute(), neg)
471}
472
473func cmdEnsureServerRunning(ts *testscript.TestScript, neg bool, args []string) {
474 if len(args) < 1 {
475 ts.Fatalf("Must supply a TCP port of one of the services to connect to. " +
476 "These are set as env vars as they are randomized. " +
477 "Example usage: \"cmdensureserverrunning SSH_PORT\"\n" +
478 "Valid values for the env var: SSH_PORT|HTTP_PORT|GIT_PORT|STATS_PORT")
479 }
480
481 port := ts.Getenv(args[0])
482
483 // verify that the server is up
484 addr := net.JoinHostPort("localhost", port)
485 for {
486 conn, _ := net.DialTimeout(
487 "tcp",
488 addr,
489 time.Second,
490 )
491 if conn != nil {
492 ts.Logf("Server is running on port: %s", port)
493 conn.Close()
494 break
495 }
496 }
497}
498
499func cmdStopserver(ts *testscript.TestScript, neg bool, args []string) {
500 // stop the server
501 resp, err := http.DefaultClient.Head(fmt.Sprintf("%s/__stop", ts.Getenv("SOFT_SERVE_HTTP_PUBLIC_URL")))
502 check(ts, err, neg)
503 resp.Body.Close()
504 time.Sleep(time.Second * 2) // Allow some time for the server to stop
505}
506
507func setupPostgres(t testscript.T, cfg *config.Config) (error, func()) {
508 // Indicates postgres
509 // Create a disposable database
510 rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
511 dbName := fmt.Sprintf("softserve_test_%d", rnd.Int63())
512 dbDsn := cfg.DB.DataSource
513 if dbDsn == "" {
514 cfg.DB.DataSource = "postgres://postgres@localhost:5432/postgres?sslmode=disable"
515 }
516
517 dbUrl, err := url.Parse(cfg.DB.DataSource)
518 if err != nil {
519 return err, nil
520 }
521
522 scheme := dbUrl.Scheme
523 if scheme == "" {
524 scheme = "postgres"
525 }
526
527 host := dbUrl.Hostname()
528 if host == "" {
529 host = "localhost"
530 }
531
532 connInfo := fmt.Sprintf("host=%s sslmode=disable", host)
533 username := dbUrl.User.Username()
534 if username != "" {
535 connInfo += fmt.Sprintf(" user=%s", username)
536 password, ok := dbUrl.User.Password()
537 if ok {
538 username = fmt.Sprintf("%s:%s", username, password)
539 connInfo += fmt.Sprintf(" password=%s", password)
540 }
541 username = fmt.Sprintf("%s@", username)
542 } else {
543 connInfo += " user=postgres"
544 username = "postgres@"
545 }
546
547 port := dbUrl.Port()
548 if port != "" {
549 connInfo += fmt.Sprintf(" port=%s", port)
550 port = fmt.Sprintf(":%s", port)
551 }
552
553 cfg.DB.DataSource = fmt.Sprintf("%s://%s%s%s/%s?sslmode=disable",
554 scheme,
555 username,
556 host,
557 port,
558 dbName,
559 )
560
561 // Create the database
562 dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
563 if err != nil {
564 return err, nil
565 }
566
567 if _, err := dbx.Exec("CREATE DATABASE " + dbName); err != nil {
568 return err, nil
569 }
570
571 return nil, func() {
572 dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
573 if err != nil {
574 t.Fatal("failed to open database", dbName, err)
575 }
576
577 if _, err := dbx.Exec("DROP DATABASE " + dbName); err != nil {
578 t.Fatal("failed to drop database", dbName, err)
579 }
580 }
581}