1package testscript
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "flag"
8 "fmt"
9 "io"
10 "math/rand"
11 "net"
12 "net/http"
13 "net/url"
14 "os"
15 "os/exec"
16 "path/filepath"
17 "runtime"
18 "strconv"
19 "strings"
20 "testing"
21 "time"
22
23 "github.com/charmbracelet/keygen"
24 "github.com/charmbracelet/soft-serve/pkg/config"
25 "github.com/charmbracelet/soft-serve/pkg/db"
26 "github.com/charmbracelet/soft-serve/pkg/test"
27 "github.com/rogpeppe/go-internal/testscript"
28 "github.com/spf13/cobra"
29 "golang.org/x/crypto/ssh"
30)
31
// update is set by passing -update to go test; it is forwarded to
// testscript.Params.UpdateScripts so script files are rewritten in place.
var update = flag.Bool("update", false, "update script files")

// binPath is the location of the soft binary built by TestMain. TestScript's
// setup prepends its directory to PATH for every script.
var binPath string
36
37func TestMain(m *testing.M) {
38 tmp, err := os.MkdirTemp("", "soft-serve*")
39 if err != nil {
40 fmt.Fprintf(os.Stderr, "failed to create temporary directory: %s", err)
41 os.Exit(1)
42 }
43 defer os.RemoveAll(tmp)
44
45 binPath = filepath.Join(tmp, "soft")
46 if runtime.GOOS == "windows" {
47 binPath += ".exe"
48 }
49
50 // Build the soft binary with -cover flag.
51 cmd := exec.Command("go", "build", "-race", "-cover", "-o", binPath, filepath.Join("..", "cmd", "soft"))
52 if err := cmd.Run(); err != nil {
53 fmt.Fprintf(os.Stderr, "failed to build soft-serve binary: %s", err)
54 os.Exit(1)
55 }
56
57 // Run tests
58 os.Exit(m.Run())
59}
60
61func TestScript(t *testing.T) {
62 flag.Parse()
63
64 mkkey := func(name string) (string, *keygen.SSHKeyPair) {
65 path := filepath.Join(t.TempDir(), name)
66 pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
67 if err != nil {
68 t.Fatal(err)
69 }
70 return path, pair
71 }
72
73 key, admin1 := mkkey("admin1")
74 _, admin2 := mkkey("admin2")
75 _, user1 := mkkey("user1")
76
77 testscript.Run(t, testscript.Params{
78 Dir: "./testdata/",
79 UpdateScripts: *update,
80 RequireExplicitExec: true,
81 Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
82 "soft": cmdSoft("admin", admin1.Signer()),
83 "usoft": cmdSoft("user1", user1.Signer()),
84 "git": cmdGit(key),
85 "curl": cmdCurl,
86 "mkfile": cmdMkfile,
87 "envfile": cmdEnvfile,
88 "readfile": cmdReadfile,
89 "dos2unix": cmdDos2Unix,
90 "new-webhook": cmdNewWebhook,
91 "waitforserver": cmdWaitforserver,
92 "stopserver": cmdStopserver,
93 "ui": cmdUI(admin1.Signer()),
94 "uui": cmdUI(user1.Signer()),
95 },
96 Setup: func(e *testscript.Env) error {
97 // Add binPath to PATH
98 e.Setenv("PATH", fmt.Sprintf("%s%c%s", filepath.Dir(binPath), os.PathListSeparator, e.Getenv("PATH")))
99
100 data := t.TempDir()
101 sshPort := test.RandomPort()
102 sshListen := fmt.Sprintf("localhost:%d", sshPort)
103 gitPort := test.RandomPort()
104 gitListen := fmt.Sprintf("localhost:%d", gitPort)
105 httpPort := test.RandomPort()
106 httpListen := fmt.Sprintf("localhost:%d", httpPort)
107 statsPort := test.RandomPort()
108 statsListen := fmt.Sprintf("localhost:%d", statsPort)
109 serverName := "Test Soft Serve"
110
111 e.Setenv("DATA_PATH", data)
112 e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
113 e.Setenv("HTTP_PORT", fmt.Sprintf("%d", httpPort))
114 e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
115 e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
116 e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
117 e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
118 e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
119
120 // This is used to set up test specific configuration and http endpoints
121 e.Setenv("SOFT_SERVE_TESTRUN", "1")
122
123 // This will disable the default lipgloss renderer colors
124 e.Setenv("SOFT_SERVE_NO_COLOR", "1")
125
126 // Soft Serve debug environment variables
127 for _, env := range []string{
128 "SOFT_SERVE_DEBUG",
129 "SOFT_SERVE_VERBOSE",
130 } {
131 if v, ok := os.LookupEnv(env); ok {
132 e.Setenv(env, v)
133 }
134 }
135
136 // TODO: test different configs
137 cfg := config.DefaultConfig()
138 cfg.DataPath = data
139 cfg.Name = serverName
140 cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
141 cfg.SSH.ListenAddr = sshListen
142 cfg.SSH.PublicURL = "ssh://" + sshListen
143 cfg.Git.ListenAddr = gitListen
144 cfg.HTTP.ListenAddr = httpListen
145 cfg.HTTP.PublicURL = "http://" + httpListen
146 cfg.Stats.ListenAddr = statsListen
147 cfg.LFS.Enabled = true
148
149 // Parse os SOFT_SERVE environment variables
150 if err := cfg.ParseEnv(); err != nil {
151 return err
152 }
153
154 // Override the database data source if we're using postgres
155 // so we can create a temporary database for the tests.
156 if cfg.DB.Driver == "postgres" {
157 err, cleanup := setupPostgres(e.T(), cfg)
158 if err != nil {
159 return err
160 }
161 if cleanup != nil {
162 e.Defer(cleanup)
163 }
164 }
165
166 for _, env := range cfg.Environ() {
167 parts := strings.SplitN(env, "=", 2)
168 if len(parts) != 2 {
169 e.T().Fatal("invalid environment variable", env)
170 }
171 e.Setenv(parts[0], parts[1])
172 }
173
174 return nil
175 },
176 })
177}
178
179func cmdSoft(user string, key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
180 return func(ts *testscript.TestScript, neg bool, args []string) {
181 cli, err := ssh.Dial(
182 "tcp",
183 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
184 &ssh.ClientConfig{
185 User: user,
186 Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
187 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
188 },
189 )
190 ts.Check(err)
191 defer cli.Close()
192
193 sess, err := cli.NewSession()
194 ts.Check(err)
195 defer sess.Close()
196
197 sess.Stdout = ts.Stdout()
198 sess.Stderr = ts.Stderr()
199
200 check(ts, sess.Run(strings.Join(args, " ")), neg)
201 }
202}
203
204func cmdUI(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
205 return func(ts *testscript.TestScript, neg bool, args []string) {
206 if len(args) < 1 {
207 ts.Fatalf("usage: ui <quoted string input>")
208 return
209 }
210
211 cli, err := ssh.Dial(
212 "tcp",
213 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
214 &ssh.ClientConfig{
215 User: "git",
216 Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
217 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
218 },
219 )
220 check(ts, err, neg)
221 defer cli.Close()
222
223 sess, err := cli.NewSession()
224 check(ts, err, neg)
225 defer sess.Close()
226
227 // XXX: this is a hack to make the UI tests work
228 // cmp command always complains about an extra newline
229 // in the output
230 defer ts.Stdout().Write([]byte("\n"))
231
232 sess.Stdout = ts.Stdout()
233 sess.Stderr = ts.Stderr()
234
235 stdin, err := sess.StdinPipe()
236 check(ts, err, neg)
237
238 err = sess.RequestPty("dumb", 40, 80, ssh.TerminalModes{})
239 check(ts, err, neg)
240 check(ts, sess.Start(""), neg)
241
242 in, err := strconv.Unquote(args[0])
243 check(ts, err, neg)
244 reader := strings.NewReader(in)
245 go func() {
246 defer stdin.Close()
247 for {
248 r, _, err := reader.ReadRune()
249 if err == io.EOF {
250 break
251 }
252 check(ts, err, neg)
253 stdin.Write([]byte(string(r))) // nolint: errcheck
254
255 // Wait for the UI to process the input
256 time.Sleep(100 * time.Millisecond)
257 }
258 }()
259
260 check(ts, sess.Wait(), neg)
261 }
262}
263
264// P.S. Windows sucks!
265func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
266 if neg {
267 ts.Fatalf("unsupported: ! dos2unix")
268 }
269 if len(args) < 1 {
270 ts.Fatalf("usage: dos2unix paths...")
271 }
272 for _, arg := range args {
273 filename := ts.MkAbs(arg)
274 data, err := os.ReadFile(filename)
275 if err != nil {
276 ts.Fatalf("%s: %v", filename, err)
277 }
278
279 // Replace all '\r\n' with '\n'.
280 data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
281
282 if err := os.WriteFile(filename, data, 0o644); err != nil {
283 ts.Fatalf("%s: %v", filename, err)
284 }
285 }
286}
287
// sshConfig is a fmt template for the OpenSSH client configuration that
// cmdGit writes to $SSH_KNOWN_CONFIG_FILE before each git invocation. Its
// single %q verb is filled with the path from $SSH_KNOWN_HOSTS_FILE.
//
// For every host, the template:
//   - disables strict host key checking, so new test servers are accepted;
//   - disables the SSH agent, so only the key passed via -i is used;
//   - sends a keepalive every 60 seconds.
var sshConfig = `
Host *
 UserKnownHostsFile %q
 StrictHostKeyChecking no
 IdentityAgent none
 IdentitiesOnly yes
 ServerAliveInterval 60
`
296
297func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
298 return func(ts *testscript.TestScript, neg bool, args []string) {
299 ts.Check(os.WriteFile(
300 ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
301 []byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
302 0o600,
303 ))
304 sshArgs := []string{
305 "-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
306 "-i", filepath.ToSlash(key),
307 }
308 ts.Setenv(
309 "GIT_SSH_COMMAND",
310 strings.Join(append([]string{"ssh"}, sshArgs...), " "),
311 )
312 // Disable git prompting for credentials.
313 ts.Setenv("GIT_TERMINAL_PROMPT", "0")
314 args = append([]string{
315 "-c", "user.email=john@example.com",
316 "-c", "user.name=John Doe",
317 }, args...)
318 check(ts, ts.Exec("git", args...), neg)
319 }
320}
321
322func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
323 if len(args) < 2 {
324 ts.Fatalf("usage: mkfile path content")
325 }
326 check(ts, os.WriteFile(
327 ts.MkAbs(args[0]),
328 []byte(strings.Join(args[1:], " ")),
329 0o644,
330 ), neg)
331}
332
333func check(ts *testscript.TestScript, err error, neg bool) {
334 if neg && err == nil {
335 ts.Fatalf("expected error, got nil")
336 }
337 if !neg {
338 ts.Check(err)
339 }
340}
341
342func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) {
343 ts.Stdout().Write([]byte(ts.ReadFile(args[0])))
344}
345
346func cmdEnvfile(ts *testscript.TestScript, neg bool, args []string) {
347 if len(args) < 1 {
348 ts.Fatalf("usage: envfile key=file...")
349 }
350
351 for _, arg := range args {
352 parts := strings.SplitN(arg, "=", 2)
353 if len(parts) != 2 {
354 ts.Fatalf("usage: envfile key=file...")
355 }
356 key := parts[0]
357 file := parts[1]
358 ts.Setenv(key, strings.TrimSpace(ts.ReadFile(file)))
359 }
360}
361
362func cmdNewWebhook(ts *testscript.TestScript, neg bool, args []string) {
363 type webhookSite struct {
364 UUID string `json:"uuid"`
365 }
366
367 if len(args) != 1 {
368 ts.Fatalf("usage: new-webhook <env-name>")
369 }
370
371 const whSite = "https://webhook.site"
372 req, err := http.NewRequest(http.MethodPost, whSite+"/token", nil)
373 check(ts, err, neg)
374
375 resp, err := http.DefaultClient.Do(req)
376 check(ts, err, neg)
377
378 defer resp.Body.Close()
379 var site webhookSite
380 check(ts, json.NewDecoder(resp.Body).Decode(&site), neg)
381
382 ts.Setenv(args[0], whSite+"/"+site.UUID)
383}
384
385func cmdCurl(ts *testscript.TestScript, neg bool, args []string) {
386 var verbose bool
387 var headers []string
388 var data string
389 method := http.MethodGet
390
391 cmd := &cobra.Command{
392 Use: "curl",
393 Args: cobra.MinimumNArgs(1),
394 RunE: func(cmd *cobra.Command, args []string) error {
395 url, err := url.Parse(args[0])
396 if err != nil {
397 return err
398 }
399
400 req, err := http.NewRequest(method, url.String(), nil)
401 if err != nil {
402 return err
403 }
404
405 if data != "" {
406 req.Body = io.NopCloser(strings.NewReader(data))
407 }
408
409 if verbose {
410 fmt.Fprintf(cmd.ErrOrStderr(), "< %s %s\n", req.Method, url.String())
411 }
412
413 for _, header := range headers {
414 parts := strings.SplitN(header, ":", 2)
415 if len(parts) != 2 {
416 return fmt.Errorf("invalid header: %s", header)
417 }
418 req.Header.Add(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
419 }
420
421 if userInfo := url.User; userInfo != nil {
422 password, _ := userInfo.Password()
423 req.SetBasicAuth(userInfo.Username(), password)
424 }
425
426 if verbose {
427 for key, values := range req.Header {
428 for _, value := range values {
429 fmt.Fprintf(cmd.ErrOrStderr(), "< %s: %s\n", key, value)
430 }
431 }
432 }
433
434 resp, err := http.DefaultClient.Do(req)
435 if err != nil {
436 return err
437 }
438
439 if verbose {
440 fmt.Fprintf(ts.Stderr(), "> %s\n", resp.Status)
441 for key, values := range resp.Header {
442 for _, value := range values {
443 fmt.Fprintf(cmd.ErrOrStderr(), "> %s: %s\n", key, value)
444 }
445 }
446 }
447
448 defer resp.Body.Close()
449 buf, err := io.ReadAll(resp.Body)
450 if err != nil {
451 return err
452 }
453
454 cmd.Print(string(buf))
455
456 return nil
457 },
458 }
459
460 cmd.SetArgs(args)
461 cmd.SetOut(ts.Stdout())
462 cmd.SetErr(ts.Stderr())
463
464 cmd.Flags().BoolVarP(&verbose, "verbose", "v", verbose, "verbose")
465 cmd.Flags().StringArrayVarP(&headers, "header", "H", nil, "HTTP header")
466 cmd.Flags().StringVarP(&method, "request", "X", method, "HTTP method")
467 cmd.Flags().StringVarP(&data, "data", "d", data, "HTTP data")
468
469 check(ts, cmd.Execute(), neg)
470}
471
472func cmdWaitforserver(ts *testscript.TestScript, neg bool, args []string) {
473 // wait until the server is up
474 addr := net.JoinHostPort("localhost", ts.Getenv("SSH_PORT"))
475 for {
476 conn, _ := net.DialTimeout(
477 "tcp",
478 addr,
479 time.Second,
480 )
481 if conn != nil {
482 conn.Close()
483 break
484 }
485 }
486}
487
488func cmdStopserver(ts *testscript.TestScript, neg bool, args []string) {
489 // stop the server
490 resp, err := http.DefaultClient.Head(fmt.Sprintf("%s/__stop", ts.Getenv("SOFT_SERVE_HTTP_PUBLIC_URL")))
491 check(ts, err, neg)
492 resp.Body.Close()
493 time.Sleep(time.Second * 2) // Allow some time for the server to stop
494}
495
496func setupPostgres(t testscript.T, cfg *config.Config) (error, func()) {
497 // Indicates postgres
498 // Create a disposable database
499 rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
500 dbName := fmt.Sprintf("softserve_test_%d", rnd.Int63())
501 dbDsn := cfg.DB.DataSource
502 if dbDsn == "" {
503 cfg.DB.DataSource = "postgres://postgres@localhost:5432/postgres?sslmode=disable"
504 }
505
506 dbUrl, err := url.Parse(cfg.DB.DataSource)
507 if err != nil {
508 return err, nil
509 }
510
511 scheme := dbUrl.Scheme
512 if scheme == "" {
513 scheme = "postgres"
514 }
515
516 host := dbUrl.Hostname()
517 if host == "" {
518 host = "localhost"
519 }
520
521 connInfo := fmt.Sprintf("host=%s sslmode=disable", host)
522 username := dbUrl.User.Username()
523 if username != "" {
524 connInfo += fmt.Sprintf(" user=%s", username)
525 password, ok := dbUrl.User.Password()
526 if ok {
527 username = fmt.Sprintf("%s:%s", username, password)
528 connInfo += fmt.Sprintf(" password=%s", password)
529 }
530 username = fmt.Sprintf("%s@", username)
531 } else {
532 connInfo += " user=postgres"
533 username = "postgres@"
534 }
535
536 port := dbUrl.Port()
537 if port != "" {
538 connInfo += fmt.Sprintf(" port=%s", port)
539 port = fmt.Sprintf(":%s", port)
540 }
541
542 cfg.DB.DataSource = fmt.Sprintf("%s://%s%s%s/%s?sslmode=disable",
543 scheme,
544 username,
545 host,
546 port,
547 dbName,
548 )
549
550 // Create the database
551 dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
552 if err != nil {
553 return err, nil
554 }
555
556 if _, err := dbx.Exec("CREATE DATABASE " + dbName); err != nil {
557 return err, nil
558 }
559
560 return nil, func() {
561 dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
562 if err != nil {
563 t.Fatal("failed to open database", dbName, err)
564 }
565
566 if _, err := dbx.Exec("DROP DATABASE " + dbName); err != nil {
567 t.Fatal("failed to drop database", dbName, err)
568 }
569 }
570}