1package testscript
  2
  3import (
  4	"bytes"
  5	"context"
  6	"database/sql"
  7	"flag"
  8	"fmt"
  9	"io"
 10	"net"
 11	"net/http"
 12	"net/url"
 13	"os"
 14	"path/filepath"
 15	"strings"
 16	"sync"
 17	"testing"
 18	"time"
 19
 20	"github.com/charmbracelet/keygen"
 21	"github.com/charmbracelet/log"
 22	"github.com/charmbracelet/soft-serve/server"
 23	"github.com/charmbracelet/soft-serve/server/backend"
 24	"github.com/charmbracelet/soft-serve/server/config"
 25	"github.com/charmbracelet/soft-serve/server/db"
 26	"github.com/charmbracelet/soft-serve/server/db/migrate"
 27	logr "github.com/charmbracelet/soft-serve/server/log"
 28	"github.com/charmbracelet/soft-serve/server/store"
 29	"github.com/charmbracelet/soft-serve/server/store/database"
 30	"github.com/charmbracelet/soft-serve/server/test"
 31	"github.com/rogpeppe/go-internal/testscript"
 32	"github.com/spf13/cobra"
 33	"golang.org/x/crypto/ssh"
 34)
 35
// update is the -update command-line flag. TestScript forwards its value to
// testscript.Params.UpdateScripts.
var update = flag.Bool("update", false, "update script files")
 37
 38func TestScript(t *testing.T) {
 39	flag.Parse()
 40	var lock sync.Mutex
 41
 42	mkkey := func(name string) (string, *keygen.SSHKeyPair) {
 43		path := filepath.Join(t.TempDir(), name)
 44		pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
 45		if err != nil {
 46			t.Fatal(err)
 47		}
 48		return path, pair
 49	}
 50
 51	key, admin1 := mkkey("admin1")
 52	_, admin2 := mkkey("admin2")
 53	_, user1 := mkkey("user1")
 54
 55	testscript.Run(t, testscript.Params{
 56		Dir:           "./testdata/",
 57		UpdateScripts: *update,
 58		Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
 59			"soft":     cmdSoft(admin1.Signer()),
 60			"usoft":    cmdSoft(user1.Signer()),
 61			"git":      cmdGit(key),
 62			"curl":     cmdCurl,
 63			"mkfile":   cmdMkfile,
 64			"envfile":  cmdEnvfile,
 65			"readfile": cmdReadfile,
 66			"dos2unix": cmdDos2Unix,
 67		},
 68		Setup: func(e *testscript.Env) error {
 69			data := t.TempDir()
 70
 71			sshPort := test.RandomPort()
 72			sshListen := fmt.Sprintf("localhost:%d", sshPort)
 73			gitPort := test.RandomPort()
 74			gitListen := fmt.Sprintf("localhost:%d", gitPort)
 75			httpPort := test.RandomPort()
 76			httpListen := fmt.Sprintf("localhost:%d", httpPort)
 77			statsPort := test.RandomPort()
 78			statsListen := fmt.Sprintf("localhost:%d", statsPort)
 79			serverName := "Test Soft Serve"
 80
 81			e.Setenv("DATA_PATH", data)
 82			e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
 83			e.Setenv("HTTP_PORT", fmt.Sprintf("%d", httpPort))
 84			e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
 85			e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
 86			e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
 87			e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
 88			e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
 89
 90			cfg := config.DefaultConfig()
 91			cfg.DataPath = data
 92			cfg.Name = serverName
 93			cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
 94			cfg.SSH.ListenAddr = sshListen
 95			cfg.SSH.PublicURL = "ssh://" + sshListen
 96			cfg.Git.ListenAddr = gitListen
 97			cfg.HTTP.ListenAddr = httpListen
 98			cfg.HTTP.PublicURL = "http://" + httpListen
 99			cfg.Stats.ListenAddr = statsListen
100			cfg.DB.Driver = "sqlite"
101			cfg.LFS.Enabled = true
102			cfg.LFS.SSHEnabled = true
103
104			dbDriver := os.Getenv("DB_DRIVER")
105			if dbDriver != "" {
106				cfg.DB.Driver = dbDriver
107			}
108
109			dbDsn := os.Getenv("DB_DATA_SOURCE")
110			if dbDsn != "" {
111				cfg.DB.DataSource = dbDsn
112			}
113
114			if cfg.DB.Driver == "postgres" {
115				err, cleanup := setupPostgres(e.T(), cfg)
116				if err != nil {
117					return err
118				}
119				if cleanup != nil {
120					e.Defer(cleanup)
121				}
122			}
123
124			if err := cfg.Validate(); err != nil {
125				return err
126			}
127
128			ctx := config.WithContext(context.Background(), cfg)
129
130			logger, f, err := logr.NewLogger(cfg)
131			if err != nil {
132				log.Errorf("failed to create logger: %v", err)
133			}
134
135			ctx = log.WithContext(ctx, logger)
136			if f != nil {
137				defer f.Close() // nolint: errcheck
138			}
139
140			dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
141			if err != nil {
142				return fmt.Errorf("open database: %w", err)
143			}
144
145			if err := migrate.Migrate(ctx, dbx); err != nil {
146				return fmt.Errorf("migrate database: %w", err)
147			}
148
149			ctx = db.WithContext(ctx, dbx)
150			datastore := database.New(ctx, dbx)
151			ctx = store.WithContext(ctx, datastore)
152			be := backend.New(ctx, cfg, dbx)
153			ctx = backend.WithContext(ctx, be)
154
155			lock.Lock()
156			srv, err := server.NewServer(ctx)
157			if err != nil {
158				lock.Unlock()
159				return err
160			}
161			lock.Unlock()
162
163			go func() {
164				if err := srv.Start(); err != nil {
165					e.T().Fatal(err)
166				}
167			}()
168
169			e.Defer(func() {
170				defer dbx.Close() // nolint: errcheck
171				ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
172				defer cancel()
173				lock.Lock()
174				defer lock.Unlock()
175				if err := srv.Shutdown(ctx); err != nil {
176					e.T().Fatal(err)
177				}
178			})
179
180			// wait until the server is up
181			for {
182				conn, _ := net.DialTimeout(
183					"tcp",
184					net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
185					time.Second,
186				)
187				if conn != nil {
188					conn.Close()
189					break
190				}
191			}
192
193			return nil
194		},
195	})
196}
197
198func cmdSoft(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
199	return func(ts *testscript.TestScript, neg bool, args []string) {
200		cli, err := ssh.Dial(
201			"tcp",
202			net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
203			&ssh.ClientConfig{
204				User:            "admin",
205				Auth:            []ssh.AuthMethod{ssh.PublicKeys(key)},
206				HostKeyCallback: ssh.InsecureIgnoreHostKey(),
207			},
208		)
209		ts.Check(err)
210		defer cli.Close()
211
212		sess, err := cli.NewSession()
213		ts.Check(err)
214		defer sess.Close()
215
216		sess.Stdout = ts.Stdout()
217		sess.Stderr = ts.Stderr()
218
219		check(ts, sess.Run(strings.Join(args, " ")), neg)
220	}
221}
222
223// P.S. Windows sucks!
224func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
225	if neg {
226		ts.Fatalf("unsupported: ! dos2unix")
227	}
228	if len(args) < 1 {
229		ts.Fatalf("usage: dos2unix paths...")
230	}
231	for _, arg := range args {
232		filename := ts.MkAbs(arg)
233		data, err := os.ReadFile(filename)
234		if err != nil {
235			ts.Fatalf("%s: %v", filename, err)
236		}
237
238		// Replace all '\r\n' with '\n'.
239		data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
240
241		if err := os.WriteFile(filename, data, 0o644); err != nil {
242			ts.Fatalf("%s: %v", filename, err)
243		}
244	}
245}
246
// sshConfig is the ssh client configuration template that cmdGit writes to
// the file named by SSH_KNOWN_CONFIG_FILE. Its single %q verb receives the
// value of SSH_KNOWN_HOSTS_FILE.
var sshConfig = `
Host *
  UserKnownHostsFile %q
  StrictHostKeyChecking no
  IdentityAgent none
  IdentitiesOnly yes
  ServerAliveInterval 60
`
255
256func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
257	return func(ts *testscript.TestScript, neg bool, args []string) {
258		ts.Check(os.WriteFile(
259			ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
260			[]byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
261			0o600,
262		))
263		sshArgs := []string{
264			"-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
265			"-i", filepath.ToSlash(key),
266		}
267		ts.Setenv(
268			"GIT_SSH_COMMAND",
269			strings.Join(append([]string{"ssh"}, sshArgs...), " "),
270		)
271		// Disable git prompting for credentials.
272		ts.Setenv("GIT_TERMINAL_PROMPT", "0")
273		args = append([]string{
274			"-c", "user.email=john@example.com",
275			"-c", "user.name=John Doe",
276		}, args...)
277		check(ts, ts.Exec("git", args...), neg)
278	}
279}
280
281func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
282	if len(args) < 2 {
283		ts.Fatalf("usage: mkfile path content")
284	}
285	check(ts, os.WriteFile(
286		ts.MkAbs(args[0]),
287		[]byte(strings.Join(args[1:], " ")),
288		0o644,
289	), neg)
290}
291
292func check(ts *testscript.TestScript, err error, neg bool) {
293	if neg && err == nil {
294		ts.Fatalf("expected error, got nil")
295	}
296	if !neg {
297		ts.Check(err)
298	}
299}
300
301func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) {
302	ts.Stdout().Write([]byte(ts.ReadFile(args[0])))
303}
304
305func cmdEnvfile(ts *testscript.TestScript, neg bool, args []string) {
306	if len(args) < 1 {
307		ts.Fatalf("usage: envfile key=file...")
308	}
309
310	for _, arg := range args {
311		parts := strings.SplitN(arg, "=", 2)
312		if len(parts) != 2 {
313			ts.Fatalf("usage: envfile key=file...")
314		}
315		key := parts[0]
316		file := parts[1]
317		ts.Setenv(key, strings.TrimSpace(ts.ReadFile(file)))
318	}
319}
320
321func cmdCurl(ts *testscript.TestScript, neg bool, args []string) {
322	var verbose bool
323	var headers []string
324	var data string
325	method := http.MethodGet
326
327	cmd := &cobra.Command{
328		Use:  "curl",
329		Args: cobra.MinimumNArgs(1),
330		RunE: func(cmd *cobra.Command, args []string) error {
331			url, err := url.Parse(args[0])
332			if err != nil {
333				return err
334			}
335
336			req, err := http.NewRequest(method, url.String(), nil)
337			if err != nil {
338				return err
339			}
340
341			if data != "" {
342				req.Body = io.NopCloser(strings.NewReader(data))
343			}
344
345			if verbose {
346				fmt.Fprintf(cmd.ErrOrStderr(), "< %s %s\n", req.Method, url.String())
347			}
348
349			for _, header := range headers {
350				parts := strings.SplitN(header, ":", 2)
351				if len(parts) != 2 {
352					return fmt.Errorf("invalid header: %s", header)
353				}
354				req.Header.Add(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
355			}
356
357			if userInfo := url.User; userInfo != nil {
358				password, _ := userInfo.Password()
359				req.SetBasicAuth(userInfo.Username(), password)
360			}
361
362			if verbose {
363				for key, values := range req.Header {
364					for _, value := range values {
365						fmt.Fprintf(cmd.ErrOrStderr(), "< %s: %s\n", key, value)
366					}
367				}
368			}
369
370			resp, err := http.DefaultClient.Do(req)
371			if err != nil {
372				return err
373			}
374
375			if verbose {
376				fmt.Fprintf(ts.Stderr(), "> %s\n", resp.Status)
377				for key, values := range resp.Header {
378					for _, value := range values {
379						fmt.Fprintf(cmd.ErrOrStderr(), "> %s: %s\n", key, value)
380					}
381				}
382			}
383
384			defer resp.Body.Close()
385			buf, err := io.ReadAll(resp.Body)
386			if err != nil {
387				return err
388			}
389
390			cmd.Print(string(buf))
391
392			return nil
393		},
394	}
395
396	cmd.SetArgs(args)
397	cmd.SetOut(ts.Stdout())
398	cmd.SetErr(ts.Stderr())
399
400	cmd.Flags().BoolVarP(&verbose, "verbose", "v", verbose, "verbose")
401	cmd.Flags().StringArrayVarP(&headers, "header", "H", nil, "HTTP header")
402	cmd.Flags().StringVarP(&method, "request", "X", method, "HTTP method")
403	cmd.Flags().StringVarP(&data, "data", "d", data, "HTTP data")
404
405	check(ts, cmd.Execute(), neg)
406}
407
408func setupPostgres(t testscript.T, cfg *config.Config) (error, func()) {
409	// Indicates postgres
410	// Create a disposable database
411	dbName := fmt.Sprintf("softserve_test_%d", time.Now().UnixNano())
412	dbDsn := os.Getenv("DB_DATA_SOURCE")
413	if dbDsn == "" {
414		cfg.DB.DataSource = "postgres://postgres@localhost:5432/postgres?sslmode=disable"
415	}
416
417	dbUrl, err := url.Parse(cfg.DB.DataSource)
418	if err != nil {
419		return err, nil
420	}
421
422	connInfo := fmt.Sprintf("host=%s sslmode=disable", dbUrl.Hostname())
423	username := dbUrl.User.Username()
424	if username != "" {
425		connInfo += fmt.Sprintf(" user=%s", username)
426		password, ok := dbUrl.User.Password()
427		if ok {
428			username = fmt.Sprintf("%s:%s", username, password)
429			connInfo += fmt.Sprintf(" password=%s", password)
430		}
431		username = fmt.Sprintf("%s@", username)
432	} else {
433		connInfo += " user=postgres"
434	}
435
436	port := dbUrl.Port()
437	if port != "" {
438		connInfo += fmt.Sprintf(" port=%s", port)
439		port = fmt.Sprintf(":%s", port)
440	}
441
442	cfg.DB.DataSource = fmt.Sprintf("%s://%s%s%s/%s?sslmode=disable",
443		dbUrl.Scheme,
444		username,
445		dbUrl.Hostname(),
446		port,
447		dbName,
448	)
449
450	// Create the database
451	db, err := sql.Open(cfg.DB.Driver, connInfo)
452	if err != nil {
453		return err, nil
454	}
455
456	if _, err := db.Exec("CREATE DATABASE " + dbName); err != nil {
457		return err, nil
458	}
459
460	return nil, func() {
461		db, err := sql.Open(cfg.DB.Driver, connInfo)
462		if err != nil {
463			t.Log("failed to open database", dbName, err)
464			return
465		}
466
467		if _, err := db.Exec("DROP DATABASE " + dbName); err != nil {
468			t.Log("failed to drop database", dbName, err)
469		}
470	}
471}