1package testscript
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "flag"
8 "fmt"
9 "io"
10 "math/rand"
11 "net"
12 "net/http"
13 "net/url"
14 "os"
15 "os/exec"
16 "path/filepath"
17 "runtime"
18 "strings"
19 "testing"
20 "time"
21
22 "github.com/charmbracelet/keygen"
23 "github.com/charmbracelet/soft-serve/pkg/config"
24 "github.com/charmbracelet/soft-serve/pkg/db"
25 "github.com/charmbracelet/soft-serve/pkg/test"
26 "github.com/rogpeppe/go-internal/testscript"
27 "github.com/spf13/cobra"
28 "golang.org/x/crypto/ssh"
29)
30
// update, when set with -update, is forwarded to testscript as UpdateScripts
// so golden content embedded in the script files is rewritten instead of
// causing a mismatch failure.
var update = flag.Bool("update", false, "update script files")

// binPath is where TestMain places the freshly built soft binary. Its
// directory is prepended to PATH for every script.
var binPath string
35
36func TestMain(m *testing.M) {
37 tmp, err := os.MkdirTemp("", "soft-serve*")
38 if err != nil {
39 fmt.Fprintf(os.Stderr, "failed to create temporary directory: %s", err)
40 os.Exit(1)
41 }
42 defer os.RemoveAll(tmp)
43
44 binPath = filepath.Join(tmp, "soft")
45 if runtime.GOOS == "windows" {
46 binPath += ".exe"
47 }
48
49 // Build the soft binary with -cover flag.
50 cmd := exec.Command("go", "build", "-race", "-cover", "-o", binPath, filepath.Join("..", "cmd", "soft"))
51 if err := cmd.Run(); err != nil {
52 fmt.Fprintf(os.Stderr, "failed to build soft-serve binary: %s", err)
53 os.Exit(1)
54 }
55
56 // Run tests
57 os.Exit(m.Run())
58
59 // Add binPath to PATH
60 os.Setenv("PATH", fmt.Sprintf("%s%c%s", os.Getenv("PATH"), os.PathListSeparator, filepath.Dir(binPath)))
61}
62
63func TestScript(t *testing.T) {
64 flag.Parse()
65
66 mkkey := func(name string) (string, *keygen.SSHKeyPair) {
67 path := filepath.Join(t.TempDir(), name)
68 pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
69 if err != nil {
70 t.Fatal(err)
71 }
72 return path, pair
73 }
74
75 key, admin1 := mkkey("admin1")
76 _, admin2 := mkkey("admin2")
77 _, user1 := mkkey("user1")
78
79 testscript.Run(t, testscript.Params{
80 Dir: "./testdata/",
81 UpdateScripts: *update,
82 RequireExplicitExec: true,
83 Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
84 "soft": cmdSoft(admin1.Signer()),
85 "usoft": cmdSoft(user1.Signer()),
86 "git": cmdGit(key),
87 "curl": cmdCurl,
88 "mkfile": cmdMkfile,
89 "envfile": cmdEnvfile,
90 "readfile": cmdReadfile,
91 "dos2unix": cmdDos2Unix,
92 "new-webhook": cmdNewWebhook,
93 "waitforserver": cmdWaitforserver,
94 "stopserver": cmdStopserver,
95 },
96 Setup: func(e *testscript.Env) error {
97 // Add binPath to PATH
98 e.Setenv("PATH", fmt.Sprintf("%s%c%s", filepath.Dir(binPath), os.PathListSeparator, e.Getenv("PATH")))
99
100 data := t.TempDir()
101 sshPort := test.RandomPort()
102 sshListen := fmt.Sprintf("localhost:%d", sshPort)
103 gitPort := test.RandomPort()
104 gitListen := fmt.Sprintf("localhost:%d", gitPort)
105 httpPort := test.RandomPort()
106 httpListen := fmt.Sprintf("localhost:%d", httpPort)
107 statsPort := test.RandomPort()
108 statsListen := fmt.Sprintf("localhost:%d", statsPort)
109 serverName := "Test Soft Serve"
110
111 e.Setenv("DATA_PATH", data)
112 e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
113 e.Setenv("HTTP_PORT", fmt.Sprintf("%d", httpPort))
114 e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
115 e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
116 e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
117 e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
118 e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
119
120 // This is used to set up test specific configuration and http endpoints
121 e.Setenv("SOFT_SERVE_TESTRUN", "1")
122
123 // Soft Serve debug environment variables
124 for _, env := range []string{
125 "SOFT_SERVE_DEBUG",
126 "SOFT_SERVE_VERBOSE",
127 } {
128 if v, ok := os.LookupEnv(env); ok {
129 e.Setenv(env, v)
130 }
131 }
132
133 // TODO: test different configs
134 cfg := config.DefaultConfig()
135 cfg.DataPath = data
136 cfg.Name = serverName
137 cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
138 cfg.SSH.ListenAddr = sshListen
139 cfg.SSH.PublicURL = "ssh://" + sshListen
140 cfg.Git.ListenAddr = gitListen
141 cfg.HTTP.ListenAddr = httpListen
142 cfg.HTTP.PublicURL = "http://" + httpListen
143 cfg.Stats.ListenAddr = statsListen
144 cfg.LFS.Enabled = true
145 // cfg.LFS.SSHEnabled = true
146
147 // Parse os SOFT_SERVE environment variables
148 if err := cfg.ParseEnv(); err != nil {
149 return err
150 }
151
152 // Override the database data source if we're using postgres
153 // so we can create a temporary database for the tests.
154 if cfg.DB.Driver == "postgres" {
155 err, cleanup := setupPostgres(e.T(), cfg)
156 if err != nil {
157 return err
158 }
159 if cleanup != nil {
160 e.Defer(cleanup)
161 }
162 }
163
164 for _, env := range cfg.Environ() {
165 parts := strings.SplitN(env, "=", 2)
166 if len(parts) != 2 {
167 e.T().Fatal("invalid environment variable", env)
168 }
169 e.Setenv(parts[0], parts[1])
170 }
171
172 return nil
173 },
174 })
175}
176
177func cmdSoft(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
178 return func(ts *testscript.TestScript, neg bool, args []string) {
179 cli, err := ssh.Dial(
180 "tcp",
181 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
182 &ssh.ClientConfig{
183 User: "admin",
184 Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
185 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
186 },
187 )
188 ts.Check(err)
189 defer cli.Close()
190
191 sess, err := cli.NewSession()
192 ts.Check(err)
193 defer sess.Close()
194
195 sess.Stdout = ts.Stdout()
196 sess.Stderr = ts.Stderr()
197
198 check(ts, sess.Run(strings.Join(args, " ")), neg)
199 }
200}
201
202// P.S. Windows sucks!
203func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
204 if neg {
205 ts.Fatalf("unsupported: ! dos2unix")
206 }
207 if len(args) < 1 {
208 ts.Fatalf("usage: dos2unix paths...")
209 }
210 for _, arg := range args {
211 filename := ts.MkAbs(arg)
212 data, err := os.ReadFile(filename)
213 if err != nil {
214 ts.Fatalf("%s: %v", filename, err)
215 }
216
217 // Replace all '\r\n' with '\n'.
218 data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
219
220 if err := os.WriteFile(filename, data, 0o644); err != nil {
221 ts.Fatalf("%s: %v", filename, err)
222 }
223 }
224}
225
// sshConfig is the OpenSSH client configuration template that cmdGit writes to
// $SSH_KNOWN_CONFIG_FILE before each git invocation. Its single %q verb is
// filled with the path from $SSH_KNOWN_HOSTS_FILE. Unknown host keys are
// accepted (StrictHostKeyChecking no), and the agent is disabled with only
// explicitly configured identities offered, so the key cmdGit passes with -i
// is the one used.
var sshConfig = `
Host *
  UserKnownHostsFile %q
  StrictHostKeyChecking no
  IdentityAgent none
  IdentitiesOnly yes
  ServerAliveInterval 60
`
234
235func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
236 return func(ts *testscript.TestScript, neg bool, args []string) {
237 ts.Check(os.WriteFile(
238 ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
239 []byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
240 0o600,
241 ))
242 sshArgs := []string{
243 "-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
244 "-i", filepath.ToSlash(key),
245 }
246 ts.Setenv(
247 "GIT_SSH_COMMAND",
248 strings.Join(append([]string{"ssh"}, sshArgs...), " "),
249 )
250 // Disable git prompting for credentials.
251 ts.Setenv("GIT_TERMINAL_PROMPT", "0")
252 args = append([]string{
253 "-c", "user.email=john@example.com",
254 "-c", "user.name=John Doe",
255 }, args...)
256 check(ts, ts.Exec("git", args...), neg)
257 }
258}
259
260func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
261 if len(args) < 2 {
262 ts.Fatalf("usage: mkfile path content")
263 }
264 check(ts, os.WriteFile(
265 ts.MkAbs(args[0]),
266 []byte(strings.Join(args[1:], " ")),
267 0o644,
268 ), neg)
269}
270
271func check(ts *testscript.TestScript, err error, neg bool) {
272 if neg && err == nil {
273 ts.Fatalf("expected error, got nil")
274 }
275 if !neg {
276 ts.Check(err)
277 }
278}
279
280func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) {
281 ts.Stdout().Write([]byte(ts.ReadFile(args[0])))
282}
283
284func cmdEnvfile(ts *testscript.TestScript, neg bool, args []string) {
285 if len(args) < 1 {
286 ts.Fatalf("usage: envfile key=file...")
287 }
288
289 for _, arg := range args {
290 parts := strings.SplitN(arg, "=", 2)
291 if len(parts) != 2 {
292 ts.Fatalf("usage: envfile key=file...")
293 }
294 key := parts[0]
295 file := parts[1]
296 ts.Setenv(key, strings.TrimSpace(ts.ReadFile(file)))
297 }
298}
299
300func cmdNewWebhook(ts *testscript.TestScript, neg bool, args []string) {
301 type webhookSite struct {
302 UUID string `json:"uuid"`
303 }
304
305 if len(args) != 1 {
306 ts.Fatalf("usage: new-webhook <env-name>")
307 }
308
309 const whSite = "https://webhook.site"
310 req, err := http.NewRequest(http.MethodPost, whSite+"/token", nil)
311 check(ts, err, neg)
312
313 resp, err := http.DefaultClient.Do(req)
314 check(ts, err, neg)
315
316 defer resp.Body.Close()
317 var site webhookSite
318 check(ts, json.NewDecoder(resp.Body).Decode(&site), neg)
319
320 ts.Setenv(args[0], whSite+"/"+site.UUID)
321}
322
323func cmdCurl(ts *testscript.TestScript, neg bool, args []string) {
324 var verbose bool
325 var headers []string
326 var data string
327 method := http.MethodGet
328
329 cmd := &cobra.Command{
330 Use: "curl",
331 Args: cobra.MinimumNArgs(1),
332 RunE: func(cmd *cobra.Command, args []string) error {
333 url, err := url.Parse(args[0])
334 if err != nil {
335 return err
336 }
337
338 req, err := http.NewRequest(method, url.String(), nil)
339 if err != nil {
340 return err
341 }
342
343 if data != "" {
344 req.Body = io.NopCloser(strings.NewReader(data))
345 }
346
347 if verbose {
348 fmt.Fprintf(cmd.ErrOrStderr(), "< %s %s\n", req.Method, url.String())
349 }
350
351 for _, header := range headers {
352 parts := strings.SplitN(header, ":", 2)
353 if len(parts) != 2 {
354 return fmt.Errorf("invalid header: %s", header)
355 }
356 req.Header.Add(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
357 }
358
359 if userInfo := url.User; userInfo != nil {
360 password, _ := userInfo.Password()
361 req.SetBasicAuth(userInfo.Username(), password)
362 }
363
364 if verbose {
365 for key, values := range req.Header {
366 for _, value := range values {
367 fmt.Fprintf(cmd.ErrOrStderr(), "< %s: %s\n", key, value)
368 }
369 }
370 }
371
372 resp, err := http.DefaultClient.Do(req)
373 if err != nil {
374 return err
375 }
376
377 if verbose {
378 fmt.Fprintf(ts.Stderr(), "> %s\n", resp.Status)
379 for key, values := range resp.Header {
380 for _, value := range values {
381 fmt.Fprintf(cmd.ErrOrStderr(), "> %s: %s\n", key, value)
382 }
383 }
384 }
385
386 defer resp.Body.Close()
387 buf, err := io.ReadAll(resp.Body)
388 if err != nil {
389 return err
390 }
391
392 cmd.Print(string(buf))
393
394 return nil
395 },
396 }
397
398 cmd.SetArgs(args)
399 cmd.SetOut(ts.Stdout())
400 cmd.SetErr(ts.Stderr())
401
402 cmd.Flags().BoolVarP(&verbose, "verbose", "v", verbose, "verbose")
403 cmd.Flags().StringArrayVarP(&headers, "header", "H", nil, "HTTP header")
404 cmd.Flags().StringVarP(&method, "request", "X", method, "HTTP method")
405 cmd.Flags().StringVarP(&data, "data", "d", data, "HTTP data")
406
407 check(ts, cmd.Execute(), neg)
408}
409
410func cmdWaitforserver(ts *testscript.TestScript, neg bool, args []string) {
411 // wait until the server is up
412 for {
413 conn, _ := net.DialTimeout(
414 "tcp",
415 net.JoinHostPort("localhost", fmt.Sprintf("%s", ts.Getenv("SSH_PORT"))),
416 time.Second,
417 )
418 if conn != nil {
419 conn.Close()
420 break
421 }
422 }
423}
424
425func cmdStopserver(ts *testscript.TestScript, neg bool, args []string) {
426 // stop the server
427 resp, err := http.DefaultClient.Head(fmt.Sprintf("%s/__stop", ts.Getenv("SOFT_SERVE_HTTP_PUBLIC_URL")))
428 check(ts, err, neg)
429 defer resp.Body.Close()
430 time.Sleep(time.Second * 2) // Allow some time for the server to stop
431}
432
433func setupPostgres(t testscript.T, cfg *config.Config) (error, func()) {
434 // Indicates postgres
435 // Create a disposable database
436 rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
437 dbName := fmt.Sprintf("softserve_test_%d", rnd.Int63())
438 dbDsn := cfg.DB.DataSource
439 if dbDsn == "" {
440 cfg.DB.DataSource = "postgres://postgres@localhost:5432/postgres?sslmode=disable"
441 }
442
443 dbUrl, err := url.Parse(cfg.DB.DataSource)
444 if err != nil {
445 return err, nil
446 }
447
448 scheme := dbUrl.Scheme
449 if scheme == "" {
450 scheme = "postgres"
451 }
452
453 host := dbUrl.Hostname()
454 if host == "" {
455 host = "localhost"
456 }
457
458 connInfo := fmt.Sprintf("host=%s sslmode=disable", host)
459 username := dbUrl.User.Username()
460 if username != "" {
461 connInfo += fmt.Sprintf(" user=%s", username)
462 password, ok := dbUrl.User.Password()
463 if ok {
464 username = fmt.Sprintf("%s:%s", username, password)
465 connInfo += fmt.Sprintf(" password=%s", password)
466 }
467 username = fmt.Sprintf("%s@", username)
468 } else {
469 connInfo += " user=postgres"
470 username = "postgres@"
471 }
472
473 port := dbUrl.Port()
474 if port != "" {
475 connInfo += fmt.Sprintf(" port=%s", port)
476 port = fmt.Sprintf(":%s", port)
477 }
478
479 cfg.DB.DataSource = fmt.Sprintf("%s://%s%s%s/%s?sslmode=disable",
480 scheme,
481 username,
482 host,
483 port,
484 dbName,
485 )
486
487 // Create the database
488 dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
489 if err != nil {
490 return err, nil
491 }
492
493 if _, err := dbx.Exec("CREATE DATABASE " + dbName); err != nil {
494 return err, nil
495 }
496
497 return nil, func() {
498 dbx, err := db.Open(context.TODO(), cfg.DB.Driver, connInfo)
499 if err != nil {
500 t.Fatal("failed to open database", dbName, err)
501 }
502
503 if _, err := dbx.Exec("DROP DATABASE " + dbName); err != nil {
504 t.Fatal("failed to drop database", dbName, err)
505 }
506 }
507}