1package testscript
2
3import (
4 "bytes"
5 "context"
6 "database/sql"
7 "flag"
8 "fmt"
9 "io"
10 "net"
11 "net/http"
12 "net/url"
13 "os"
14 "path/filepath"
15 "strings"
16 "sync"
17 "testing"
18 "time"
19
20 "github.com/charmbracelet/keygen"
21 "github.com/charmbracelet/log"
22 "github.com/charmbracelet/soft-serve/cmd/soft/serve"
23 "github.com/charmbracelet/soft-serve/pkg/backend"
24 "github.com/charmbracelet/soft-serve/pkg/config"
25 "github.com/charmbracelet/soft-serve/pkg/db"
26 "github.com/charmbracelet/soft-serve/pkg/db/migrate"
27 logr "github.com/charmbracelet/soft-serve/pkg/log"
28 "github.com/charmbracelet/soft-serve/pkg/store"
29 "github.com/charmbracelet/soft-serve/pkg/store/database"
30 "github.com/charmbracelet/soft-serve/pkg/test"
31 "github.com/rogpeppe/go-internal/testscript"
32 "github.com/spf13/cobra"
33 "golang.org/x/crypto/ssh"
34)
35
// update is the -update command-line flag. TestScript forwards its value to
// testscript.Params.UpdateScripts.
var update = flag.Bool("update", false, "update script files")
37
38func TestScript(t *testing.T) {
39 flag.Parse()
40 var lock sync.Mutex
41
42 mkkey := func(name string) (string, *keygen.SSHKeyPair) {
43 path := filepath.Join(t.TempDir(), name)
44 pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
45 if err != nil {
46 t.Fatal(err)
47 }
48 return path, pair
49 }
50
51 key, admin1 := mkkey("admin1")
52 _, admin2 := mkkey("admin2")
53 _, user1 := mkkey("user1")
54
55 testscript.Run(t, testscript.Params{
56 Dir: "./testdata/",
57 UpdateScripts: *update,
58 Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
59 "soft": cmdSoft(admin1.Signer()),
60 "usoft": cmdSoft(user1.Signer()),
61 "git": cmdGit(key),
62 "curl": cmdCurl,
63 "mkfile": cmdMkfile,
64 "envfile": cmdEnvfile,
65 "readfile": cmdReadfile,
66 "dos2unix": cmdDos2Unix,
67 },
68 Setup: func(e *testscript.Env) error {
69 data := t.TempDir()
70
71 sshPort := test.RandomPort()
72 sshListen := fmt.Sprintf("localhost:%d", sshPort)
73 gitPort := test.RandomPort()
74 gitListen := fmt.Sprintf("localhost:%d", gitPort)
75 httpPort := test.RandomPort()
76 httpListen := fmt.Sprintf("localhost:%d", httpPort)
77 statsPort := test.RandomPort()
78 statsListen := fmt.Sprintf("localhost:%d", statsPort)
79 serverName := "Test Soft Serve"
80
81 e.Setenv("DATA_PATH", data)
82 e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
83 e.Setenv("HTTP_PORT", fmt.Sprintf("%d", httpPort))
84 e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
85 e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
86 e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
87 e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
88 e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
89
90 cfg := config.DefaultConfig()
91 cfg.DataPath = data
92 cfg.Name = serverName
93 cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
94 cfg.SSH.ListenAddr = sshListen
95 cfg.SSH.PublicURL = "ssh://" + sshListen
96 cfg.Git.ListenAddr = gitListen
97 cfg.HTTP.ListenAddr = httpListen
98 cfg.HTTP.PublicURL = "http://" + httpListen
99 cfg.Stats.ListenAddr = statsListen
100 cfg.DB.Driver = "sqlite"
101 cfg.LFS.Enabled = true
102 cfg.LFS.SSHEnabled = true
103
104 dbDriver := os.Getenv("DB_DRIVER")
105 if dbDriver != "" {
106 cfg.DB.Driver = dbDriver
107 }
108
109 dbDsn := os.Getenv("DB_DATA_SOURCE")
110 if dbDsn != "" {
111 cfg.DB.DataSource = dbDsn
112 }
113
114 if cfg.DB.Driver == "postgres" {
115 err, cleanup := setupPostgres(e.T(), cfg)
116 if err != nil {
117 return err
118 }
119 if cleanup != nil {
120 e.Defer(cleanup)
121 }
122 }
123
124 if err := cfg.Validate(); err != nil {
125 return err
126 }
127
128 ctx := config.WithContext(context.Background(), cfg)
129
130 logger, f, err := logr.NewLogger(cfg)
131 if err != nil {
132 log.Errorf("failed to create logger: %v", err)
133 }
134
135 ctx = log.WithContext(ctx, logger)
136 if f != nil {
137 defer f.Close() // nolint: errcheck
138 }
139
140 dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
141 if err != nil {
142 return fmt.Errorf("open database: %w", err)
143 }
144
145 if err := migrate.Migrate(ctx, dbx); err != nil {
146 return fmt.Errorf("migrate database: %w", err)
147 }
148
149 ctx = db.WithContext(ctx, dbx)
150 datastore := database.New(ctx, dbx)
151 ctx = store.WithContext(ctx, datastore)
152 be := backend.New(ctx, cfg, dbx)
153 ctx = backend.WithContext(ctx, be)
154
155 lock.Lock()
156 srv, err := serve.NewServer(ctx)
157 if err != nil {
158 lock.Unlock()
159 return err
160 }
161 lock.Unlock()
162
163 go func() {
164 if err := srv.Start(); err != nil {
165 e.T().Fatal(err)
166 }
167 }()
168
169 e.Defer(func() {
170 defer dbx.Close() // nolint: errcheck
171 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
172 defer cancel()
173 lock.Lock()
174 defer lock.Unlock()
175 if err := srv.Shutdown(ctx); err != nil {
176 e.T().Fatal(err)
177 }
178 })
179
180 // wait until the server is up
181 for {
182 conn, _ := net.DialTimeout(
183 "tcp",
184 net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
185 time.Second,
186 )
187 if conn != nil {
188 conn.Close()
189 break
190 }
191 }
192
193 return nil
194 },
195 })
196}
197
198func cmdSoft(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
199 return func(ts *testscript.TestScript, neg bool, args []string) {
200 cli, err := ssh.Dial(
201 "tcp",
202 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
203 &ssh.ClientConfig{
204 User: "admin",
205 Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
206 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
207 },
208 )
209 ts.Check(err)
210 defer cli.Close()
211
212 sess, err := cli.NewSession()
213 ts.Check(err)
214 defer sess.Close()
215
216 sess.Stdout = ts.Stdout()
217 sess.Stderr = ts.Stderr()
218
219 check(ts, sess.Run(strings.Join(args, " ")), neg)
220 }
221}
222
223// P.S. Windows sucks!
224func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
225 if neg {
226 ts.Fatalf("unsupported: ! dos2unix")
227 }
228 if len(args) < 1 {
229 ts.Fatalf("usage: dos2unix paths...")
230 }
231 for _, arg := range args {
232 filename := ts.MkAbs(arg)
233 data, err := os.ReadFile(filename)
234 if err != nil {
235 ts.Fatalf("%s: %v", filename, err)
236 }
237
238 // Replace all '\r\n' with '\n'.
239 data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
240
241 if err := os.WriteFile(filename, data, 0o644); err != nil {
242 ts.Fatalf("%s: %v", filename, err)
243 }
244 }
245}
246
// sshConfig is the ssh client configuration template that cmdGit writes to
// the file named by SSH_KNOWN_CONFIG_FILE. Its single %q verb receives the
// known-hosts file path.
var sshConfig = `
Host *
 UserKnownHostsFile %q
 StrictHostKeyChecking no
 IdentityAgent none
 IdentitiesOnly yes
 ServerAliveInterval 60
`
255
256func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
257 return func(ts *testscript.TestScript, neg bool, args []string) {
258 ts.Check(os.WriteFile(
259 ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
260 []byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
261 0o600,
262 ))
263 sshArgs := []string{
264 "-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
265 "-i", filepath.ToSlash(key),
266 }
267 ts.Setenv(
268 "GIT_SSH_COMMAND",
269 strings.Join(append([]string{"ssh"}, sshArgs...), " "),
270 )
271 // Disable git prompting for credentials.
272 ts.Setenv("GIT_TERMINAL_PROMPT", "0")
273 args = append([]string{
274 "-c", "user.email=john@example.com",
275 "-c", "user.name=John Doe",
276 }, args...)
277 check(ts, ts.Exec("git", args...), neg)
278 }
279}
280
281func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
282 if len(args) < 2 {
283 ts.Fatalf("usage: mkfile path content")
284 }
285 check(ts, os.WriteFile(
286 ts.MkAbs(args[0]),
287 []byte(strings.Join(args[1:], " ")),
288 0o644,
289 ), neg)
290}
291
292func check(ts *testscript.TestScript, err error, neg bool) {
293 if neg && err == nil {
294 ts.Fatalf("expected error, got nil")
295 }
296 if !neg {
297 ts.Check(err)
298 }
299}
300
301func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) {
302 ts.Stdout().Write([]byte(ts.ReadFile(args[0])))
303}
304
305func cmdEnvfile(ts *testscript.TestScript, neg bool, args []string) {
306 if len(args) < 1 {
307 ts.Fatalf("usage: envfile key=file...")
308 }
309
310 for _, arg := range args {
311 parts := strings.SplitN(arg, "=", 2)
312 if len(parts) != 2 {
313 ts.Fatalf("usage: envfile key=file...")
314 }
315 key := parts[0]
316 file := parts[1]
317 ts.Setenv(key, strings.TrimSpace(ts.ReadFile(file)))
318 }
319}
320
321func cmdCurl(ts *testscript.TestScript, neg bool, args []string) {
322 var verbose bool
323 var headers []string
324 var data string
325 method := http.MethodGet
326
327 cmd := &cobra.Command{
328 Use: "curl",
329 Args: cobra.MinimumNArgs(1),
330 RunE: func(cmd *cobra.Command, args []string) error {
331 url, err := url.Parse(args[0])
332 if err != nil {
333 return err
334 }
335
336 req, err := http.NewRequest(method, url.String(), nil)
337 if err != nil {
338 return err
339 }
340
341 if data != "" {
342 req.Body = io.NopCloser(strings.NewReader(data))
343 }
344
345 if verbose {
346 fmt.Fprintf(cmd.ErrOrStderr(), "< %s %s\n", req.Method, url.String())
347 }
348
349 for _, header := range headers {
350 parts := strings.SplitN(header, ":", 2)
351 if len(parts) != 2 {
352 return fmt.Errorf("invalid header: %s", header)
353 }
354 req.Header.Add(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
355 }
356
357 if userInfo := url.User; userInfo != nil {
358 password, _ := userInfo.Password()
359 req.SetBasicAuth(userInfo.Username(), password)
360 }
361
362 if verbose {
363 for key, values := range req.Header {
364 for _, value := range values {
365 fmt.Fprintf(cmd.ErrOrStderr(), "< %s: %s\n", key, value)
366 }
367 }
368 }
369
370 resp, err := http.DefaultClient.Do(req)
371 if err != nil {
372 return err
373 }
374
375 if verbose {
376 fmt.Fprintf(ts.Stderr(), "> %s\n", resp.Status)
377 for key, values := range resp.Header {
378 for _, value := range values {
379 fmt.Fprintf(cmd.ErrOrStderr(), "> %s: %s\n", key, value)
380 }
381 }
382 }
383
384 defer resp.Body.Close()
385 buf, err := io.ReadAll(resp.Body)
386 if err != nil {
387 return err
388 }
389
390 cmd.Print(string(buf))
391
392 return nil
393 },
394 }
395
396 cmd.SetArgs(args)
397 cmd.SetOut(ts.Stdout())
398 cmd.SetErr(ts.Stderr())
399
400 cmd.Flags().BoolVarP(&verbose, "verbose", "v", verbose, "verbose")
401 cmd.Flags().StringArrayVarP(&headers, "header", "H", nil, "HTTP header")
402 cmd.Flags().StringVarP(&method, "request", "X", method, "HTTP method")
403 cmd.Flags().StringVarP(&data, "data", "d", data, "HTTP data")
404
405 check(ts, cmd.Execute(), neg)
406}
407
408func setupPostgres(t testscript.T, cfg *config.Config) (error, func()) {
409 // Indicates postgres
410 // Create a disposable database
411 dbName := fmt.Sprintf("softserve_test_%d", time.Now().UnixNano())
412 dbDsn := os.Getenv("DB_DATA_SOURCE")
413 if dbDsn == "" {
414 cfg.DB.DataSource = "postgres://postgres@localhost:5432/postgres?sslmode=disable"
415 }
416
417 dbUrl, err := url.Parse(cfg.DB.DataSource)
418 if err != nil {
419 return err, nil
420 }
421
422 connInfo := fmt.Sprintf("host=%s sslmode=disable", dbUrl.Hostname())
423 username := dbUrl.User.Username()
424 if username != "" {
425 connInfo += fmt.Sprintf(" user=%s", username)
426 password, ok := dbUrl.User.Password()
427 if ok {
428 username = fmt.Sprintf("%s:%s", username, password)
429 connInfo += fmt.Sprintf(" password=%s", password)
430 }
431 username = fmt.Sprintf("%s@", username)
432 } else {
433 connInfo += " user=postgres"
434 }
435
436 port := dbUrl.Port()
437 if port != "" {
438 connInfo += fmt.Sprintf(" port=%s", port)
439 port = fmt.Sprintf(":%s", port)
440 }
441
442 cfg.DB.DataSource = fmt.Sprintf("%s://%s%s%s/%s?sslmode=disable",
443 dbUrl.Scheme,
444 username,
445 dbUrl.Hostname(),
446 port,
447 dbName,
448 )
449
450 // Create the database
451 db, err := sql.Open(cfg.DB.Driver, connInfo)
452 if err != nil {
453 return err, nil
454 }
455
456 if _, err := db.Exec("CREATE DATABASE " + dbName); err != nil {
457 return err, nil
458 }
459
460 return nil, func() {
461 db, err := sql.Open(cfg.DB.Driver, connInfo)
462 if err != nil {
463 t.Log("failed to open database", dbName, err)
464 return
465 }
466
467 if _, err := db.Exec("DROP DATABASE " + dbName); err != nil {
468 t.Log("failed to drop database", dbName, err)
469 }
470 }
471}