1package testscript
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "flag"
8 "fmt"
9 "io"
10 "math/rand"
11 "net"
12 "net/http"
13 "net/url"
14 "os"
15 "os/exec"
16 "path/filepath"
17 "runtime"
18 "strconv"
19 "strings"
20 "testing"
21 "time"
22
23 "github.com/charmbracelet/keygen"
24 "github.com/charmbracelet/soft-serve/pkg/config"
25 "github.com/charmbracelet/soft-serve/pkg/db"
26 "github.com/charmbracelet/soft-serve/pkg/test"
27 "github.com/rogpeppe/go-internal/testscript"
28 "github.com/spf13/cobra"
29 "golang.org/x/crypto/ssh"
30)
31
// update is passed to testscript as Params.UpdateScripts; set it with
// -update to let failing cmp checks rewrite the script files.
var update = flag.Bool("update", false, "update script files")

// binPath is where TestMain builds the soft binary. Its directory is put on
// PATH for every script by the testscript Setup hook.
var binPath string
36
37func TestMain(m *testing.M) {
38 tmp, err := os.MkdirTemp("", "soft-serve*")
39 if err != nil {
40 fmt.Fprintf(os.Stderr, "failed to create temporary directory: %s", err)
41 os.Exit(1)
42 }
43 defer os.RemoveAll(tmp)
44
45 binPath = filepath.Join(tmp, "soft")
46 if runtime.GOOS == "windows" {
47 binPath += ".exe"
48 }
49
50 // Build the soft binary with -cover flag.
51 cmd := exec.Command("go", "build", "-race", "-cover", "-o", binPath, filepath.Join("..", "cmd", "soft"))
52 if err := cmd.Run(); err != nil {
53 fmt.Fprintf(os.Stderr, "failed to build soft-serve binary: %s", err)
54 os.Exit(1)
55 }
56
57 // Run tests
58 os.Exit(m.Run())
59
60 // Add binPath to PATH
61 os.Setenv("PATH", fmt.Sprintf("%s%c%s", os.Getenv("PATH"), os.PathListSeparator, filepath.Dir(binPath)))
62}
63
64func TestScript(t *testing.T) {
65 flag.Parse()
66
67 mkkey := func(name string) (string, *keygen.SSHKeyPair) {
68 path := filepath.Join(t.TempDir(), name)
69 pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
70 if err != nil {
71 t.Fatal(err)
72 }
73 return path, pair
74 }
75
76 key, admin1 := mkkey("admin1")
77 _, admin2 := mkkey("admin2")
78 _, user1 := mkkey("user1")
79 _, user2 := mkkey("user2")
80
81 testscript.Run(t, testscript.Params{
82 Dir: "./testdata/",
83 UpdateScripts: *update,
84 RequireExplicitExec: true,
85 Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
86 "soft": cmdSoft(admin1.Signer()),
87 "usoft": cmdSoft(user1.Signer()),
88 "git": cmdGit(key),
89 "curl": cmdCurl,
90 "mkfile": cmdMkfile,
91 "envfile": cmdEnvfile,
92 "readfile": cmdReadfile,
93 "dos2unix": cmdDos2Unix,
94 "new-webhook": cmdNewWebhook,
95 "waitforserver": cmdWaitforserver,
96 "stopserver": cmdStopserver,
97 "ui": cmdUI(admin1.Signer()),
98 "uui": cmdUI(user1.Signer()),
99 },
100 Setup: func(e *testscript.Env) error {
101 // Add binPath to PATH
102 e.Setenv("PATH", fmt.Sprintf("%s%c%s", filepath.Dir(binPath), os.PathListSeparator, e.Getenv("PATH")))
103
104 data := t.TempDir()
105 sshPort := test.RandomPort()
106 sshListen := fmt.Sprintf("localhost:%d", sshPort)
107 gitPort := test.RandomPort()
108 gitListen := fmt.Sprintf("localhost:%d", gitPort)
109 httpPort := test.RandomPort()
110 httpListen := fmt.Sprintf("localhost:%d", httpPort)
111 statsPort := test.RandomPort()
112 statsListen := fmt.Sprintf("localhost:%d", statsPort)
113 serverName := "Test Soft Serve"
114
115 e.Setenv("DATA_PATH", data)
116 e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
117 e.Setenv("HTTP_PORT", fmt.Sprintf("%d", httpPort))
118 e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
119 e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
120 e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
121 e.Setenv("USER2_AUTHORIZED_KEY", user2.AuthorizedKey())
122 e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
123 e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
124
125 // This is used to set up test specific configuration and http endpoints
126 e.Setenv("SOFT_SERVE_TESTRUN", "1")
127
128 // This will disable the default lipgloss renderer colors
129 e.Setenv("SOFT_SERVE_NO_COLOR", "1")
130
131 // Soft Serve debug environment variables
132 for _, env := range []string{
133 "SOFT_SERVE_DEBUG",
134 "SOFT_SERVE_VERBOSE",
135 } {
136 if v, ok := os.LookupEnv(env); ok {
137 e.Setenv(env, v)
138 }
139 }
140
141 // TODO: test different configs
142 cfg := config.DefaultConfig()
143 cfg.DataPath = data
144 cfg.Name = serverName
145 cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
146 cfg.SSH.ListenAddr = sshListen
147 cfg.SSH.PublicURL = "ssh://" + sshListen
148 cfg.Git.ListenAddr = gitListen
149 cfg.HTTP.ListenAddr = httpListen
150 cfg.HTTP.PublicURL = "http://" + httpListen
151 cfg.Stats.ListenAddr = statsListen
152 cfg.LFS.Enabled = true
153
154 // Parse os SOFT_SERVE environment variables
155 if err := cfg.ParseEnv(); err != nil {
156 return err
157 }
158
159 // Override the database data source if we're using postgres
160 // so we can create a temporary database for the tests.
161 if cfg.DB.Driver == "postgres" {
162 err, cleanup := setupPostgres(e.T(), cfg)
163 if err != nil {
164 return err
165 }
166 if cleanup != nil {
167 e.Defer(cleanup)
168 }
169 }
170
171 for _, env := range cfg.Environ() {
172 parts := strings.SplitN(env, "=", 2)
173 if len(parts) != 2 {
174 e.T().Fatal("invalid environment variable", env)
175 }
176 e.Setenv(parts[0], parts[1])
177 }
178
179 return nil
180 },
181 })
182}
183
184func cmdSoft(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
185 return func(ts *testscript.TestScript, neg bool, args []string) {
186 cli, err := ssh.Dial(
187 "tcp",
188 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
189 &ssh.ClientConfig{
190 User: "admin",
191 Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
192 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
193 },
194 )
195 ts.Check(err)
196 defer cli.Close()
197
198 sess, err := cli.NewSession()
199 ts.Check(err)
200 defer sess.Close()
201
202 sess.Stdout = ts.Stdout()
203 sess.Stderr = ts.Stderr()
204
205 check(ts, sess.Run(strings.Join(args, " ")), neg)
206 }
207}
208
209func cmdUI(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
210 return func(ts *testscript.TestScript, neg bool, args []string) {
211 if len(args) < 1 {
212 ts.Fatalf("usage: ui <quoted string input>")
213 return
214 }
215
216 cli, err := ssh.Dial(
217 "tcp",
218 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
219 &ssh.ClientConfig{
220 User: "git",
221 Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
222 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
223 },
224 )
225 check(ts, err, neg)
226 defer cli.Close()
227
228 sess, err := cli.NewSession()
229 check(ts, err, neg)
230 defer sess.Close()
231
232 // XXX: this is a hack to make the UI tests work
233 // cmp command always complains about an extra newline
234 // in the output
235 defer ts.Stdout().Write([]byte("\n"))
236
237 sess.Stdout = ts.Stdout()
238 sess.Stderr = ts.Stderr()
239
240 stdin, err := sess.StdinPipe()
241 check(ts, err, neg)
242
243 in, err := strconv.Unquote(args[0])
244 check(ts, err, neg)
245 reader := strings.NewReader(in)
246 go func() {
247 defer stdin.Close()
248 for {
249 r, _, err := reader.ReadRune()
250 if err == io.EOF {
251 break
252 }
253 check(ts, err, neg)
254 stdin.Write([]byte(string(r))) // nolint: errcheck
255
256 // Wait for the UI to process the input
257 time.Sleep(100 * time.Millisecond)
258 }
259 }()
260
261 err = sess.RequestPty("dumb", 40, 80, ssh.TerminalModes{})
262 check(ts, err, neg)
263 check(ts, sess.Run(""), neg)
264 }
265}
266
267// P.S. Windows sucks!
268func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
269 if neg {
270 ts.Fatalf("unsupported: ! dos2unix")
271 }
272 if len(args) < 1 {
273 ts.Fatalf("usage: dos2unix paths...")
274 }
275 for _, arg := range args {
276 filename := ts.MkAbs(arg)
277 data, err := os.ReadFile(filename)
278 if err != nil {
279 ts.Fatalf("%s: %v", filename, err)
280 }
281
282 // Replace all '\r\n' with '\n'.
283 data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
284
285 if err := os.WriteFile(filename, data, 0o644); err != nil {
286 ts.Fatalf("%s: %v", filename, err)
287 }
288 }
289}
290
// sshConfig is an ssh_config(5) template. cmdGit formats it with fmt.Sprintf
// and writes the result to $SSH_KNOWN_CONFIG_FILE before every git
// invocation. The single %q verb receives $SSH_KNOWN_HOSTS_FILE.
//
// The options relax host key checking, disable the SSH agent, and offer only
// explicitly configured identities. As a result, ssh presents just the key
// passed with -i in GIT_SSH_COMMAND.
var sshConfig = `
Host *
 UserKnownHostsFile %q
 StrictHostKeyChecking no
 IdentityAgent none
 IdentitiesOnly yes
 ServerAliveInterval 60
`
299
300func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
301 return func(ts *testscript.TestScript, neg bool, args []string) {
302 ts.Check(os.WriteFile(
303 ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
304 []byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
305 0o600,
306 ))
307 sshArgs := []string{
308 "-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
309 "-i", filepath.ToSlash(key),
310 }
311 ts.Setenv(
312 "GIT_SSH_COMMAND",
313 strings.Join(append([]string{"ssh"}, sshArgs...), " "),
314 )
315 // Disable git prompting for credentials.
316 ts.Setenv("GIT_TERMINAL_PROMPT", "0")
317 args = append([]string{
318 "-c", "user.email=john@example.com",
319 "-c", "user.name=John Doe",
320 }, args...)
321 check(ts, ts.Exec("git", args...), neg)
322 }
323}
324
325func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
326 if len(args) < 2 {
327 ts.Fatalf("usage: mkfile path content")
328 }
329 check(ts, os.WriteFile(
330 ts.MkAbs(args[0]),
331 []byte(strings.Join(args[1:], " ")),
332 0o644,
333 ), neg)
334}
335
336func check(ts *testscript.TestScript, err error, neg bool) {
337 if neg && err == nil {
338 ts.Fatalf("expected error, got nil")
339 }
340 if !neg {
341 ts.Check(err)
342 }
343}
344
345func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) {
346 ts.Stdout().Write([]byte(ts.ReadFile(args[0])))
347}
348
349func cmdEnvfile(ts *testscript.TestScript, neg bool, args []string) {
350 if len(args) < 1 {
351 ts.Fatalf("usage: envfile key=file...")
352 }
353
354 for _, arg := range args {
355 parts := strings.SplitN(arg, "=", 2)
356 if len(parts) != 2 {
357 ts.Fatalf("usage: envfile key=file...")
358 }
359 key := parts[0]
360 file := parts[1]
361 ts.Setenv(key, strings.TrimSpace(ts.ReadFile(file)))
362 }
363}
364
365func cmdNewWebhook(ts *testscript.TestScript, neg bool, args []string) {
366 type webhookSite struct {
367 UUID string `json:"uuid"`
368 }
369
370 if len(args) != 1 {
371 ts.Fatalf("usage: new-webhook <env-name>")
372 }
373
374 const whSite = "https://webhook.site"
375 req, err := http.NewRequest(http.MethodPost, whSite+"/token", nil)
376 check(ts, err, neg)
377
378 resp, err := http.DefaultClient.Do(req)
379 check(ts, err, neg)
380
381 defer resp.Body.Close()
382 var site webhookSite
383 check(ts, json.NewDecoder(resp.Body).Decode(&site), neg)
384
385 ts.Setenv(args[0], whSite+"/"+site.UUID)
386}
387
388func cmdCurl(ts *testscript.TestScript, neg bool, args []string) {
389 var verbose bool
390 var headers []string
391 var data string
392 method := http.MethodGet
393
394 cmd := &cobra.Command{
395 Use: "curl",
396 Args: cobra.MinimumNArgs(1),
397 RunE: func(cmd *cobra.Command, args []string) error {
398 url, err := url.Parse(args[0])
399 if err != nil {
400 return err
401 }
402
403 req, err := http.NewRequest(method, url.String(), nil)
404 if err != nil {
405 return err
406 }
407
408 if data != "" {
409 req.Body = io.NopCloser(strings.NewReader(data))
410 }
411
412 if verbose {
413 fmt.Fprintf(cmd.ErrOrStderr(), "< %s %s\n", req.Method, url.String())
414 }
415
416 for _, header := range headers {
417 parts := strings.SplitN(header, ":", 2)
418 if len(parts) != 2 {
419 return fmt.Errorf("invalid header: %s", header)
420 }
421 req.Header.Add(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
422 }
423
424 if userInfo := url.User; userInfo != nil {
425 password, _ := userInfo.Password()
426 req.SetBasicAuth(userInfo.Username(), password)
427 }
428
429 if verbose {
430 for key, values := range req.Header {
431 for _, value := range values {
432 fmt.Fprintf(cmd.ErrOrStderr(), "< %s: %s\n", key, value)
433 }
434 }
435 }
436
437 resp, err := http.DefaultClient.Do(req)
438 if err != nil {
439 return err
440 }
441
442 if verbose {
443 fmt.Fprintf(ts.Stderr(), "> %s\n", resp.Status)
444 for key, values := range resp.Header {
445 for _, value := range values {
446 fmt.Fprintf(cmd.ErrOrStderr(), "> %s: %s\n", key, value)
447 }
448 }
449 }
450
451 defer resp.Body.Close()
452 buf, err := io.ReadAll(resp.Body)
453 if err != nil {
454 return err
455 }
456
457 cmd.Print(string(buf))
458
459 return nil
460 },
461 }
462
463 cmd.SetArgs(args)
464 cmd.SetOut(ts.Stdout())
465 cmd.SetErr(ts.Stderr())
466
467 cmd.Flags().BoolVarP(&verbose, "verbose", "v", verbose, "verbose")
468 cmd.Flags().StringArrayVarP(&headers, "header", "H", nil, "HTTP header")
469 cmd.Flags().StringVarP(&method, "request", "X", method, "HTTP method")
470 cmd.Flags().StringVarP(&data, "data", "d", data, "HTTP data")
471
472 check(ts, cmd.Execute(), neg)
473}
474
475func cmdWaitforserver(ts *testscript.TestScript, neg bool, args []string) {
476 // wait until the server is up
477 for {
478 conn, _ := net.DialTimeout(
479 "tcp",
480 net.JoinHostPort("localhost", fmt.Sprintf("%s", ts.Getenv("SSH_PORT"))),
481 time.Second,
482 )
483 if conn != nil {
484 conn.Close()
485 break
486 }
487 }
488}
489
490func cmdStopserver(ts *testscript.TestScript, neg bool, args []string) {
491 // stop the server
492 resp, err := http.DefaultClient.Head(fmt.Sprintf("%s/__stop", ts.Getenv("SOFT_SERVE_HTTP_PUBLIC_URL")))
493 check(ts, err, neg)
494 defer resp.Body.Close()
495 time.Sleep(time.Second * 2) // Allow some time for the server to stop
496}
497
498func setupPostgres(t testscript.T, cfg *config.Config) (error, func()) {
499 // Indicates postgres
500 // Create a disposable database
501 rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
502 dbName := fmt.Sprintf("softserve_test_%d", rnd.Int63())
503 dbDsn := cfg.DB.DataSource
504 if dbDsn == "" {
505 cfg.DB.DataSource = "postgres://postgres@localhost:5432/postgres?sslmode=disable"
506 }
507
508 dbUrl, err := url.Parse(cfg.DB.DataSource)
509 if err != nil {
510 return err, nil
511 }
512
513 scheme := dbUrl.Scheme
514 if scheme == "" {
515 scheme = "postgres"
516 }
517
518 host := dbUrl.Hostname()
519 if host == "" {
520 host = "localhost"
521 }
522
523 connInfo := fmt.Sprintf("host=%s sslmode=disable", host)
524 username := dbUrl.User.Username()
525 if username != "" {
526 connInfo += fmt.Sprintf(" user=%s", username)
527 password, ok := dbUrl.User.Password()
528 if ok {
529 username = fmt.Sprintf("%s:%s", username, password)
530 connInfo += fmt.Sprintf(" password=%s", password)
531 }
532 username = fmt.Sprintf("%s@", username)
533 } else {
534 connInfo += " user=postgres"
535 username = "postgres@"
536 }
537
538 port := dbUrl.Port()
539 if port != "" {
540 connInfo += fmt.Sprintf(" port=%s", port)
541 port = fmt.Sprintf(":%s", port)
542 }
543
544 cfg.DB.DataSource = fmt.Sprintf("%s://%s%s%s/%s?sslmode=disable",
545 scheme,
546 username,
547 host,
548 port,
549 dbName,
550 )
551
552 // Create the database
553 dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
554 if err != nil {
555 return err, nil
556 }
557
558 if _, err := dbx.Exec("CREATE DATABASE " + dbName); err != nil {
559 return err, nil
560 }
561
562 return nil, func() {
563 dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
564 if err != nil {
565 t.Fatal("failed to open database", dbName, err)
566 }
567
568 if _, err := dbx.Exec("DROP DATABASE " + dbName); err != nil {
569 t.Fatal("failed to drop database", dbName, err)
570 }
571 }
572}