1package testscript
  2
  3import (
  4	"bytes"
  5	"context"
  6	"flag"
  7	"fmt"
  8	"io"
  9	"net"
 10	"net/http"
 11	"net/url"
 12	"os"
 13	"path/filepath"
 14	"strings"
 15	"sync"
 16	"testing"
 17	"time"
 18
 19	"github.com/charmbracelet/keygen"
 20	"github.com/charmbracelet/log"
 21	"github.com/charmbracelet/soft-serve/server"
 22	"github.com/charmbracelet/soft-serve/server/backend"
 23	"github.com/charmbracelet/soft-serve/server/config"
 24	"github.com/charmbracelet/soft-serve/server/db"
 25	"github.com/charmbracelet/soft-serve/server/db/migrate"
 26	logr "github.com/charmbracelet/soft-serve/server/log"
 27	"github.com/charmbracelet/soft-serve/server/store"
 28	"github.com/charmbracelet/soft-serve/server/store/database"
 29	"github.com/charmbracelet/soft-serve/server/test"
 30	"github.com/rogpeppe/go-internal/testscript"
 31	"github.com/spf13/cobra"
 32	"golang.org/x/crypto/ssh"
 33	_ "modernc.org/sqlite" // sqlite Driver
 34)
 35
 36var update = flag.Bool("update", false, "update script files")
 37
 38func TestScript(t *testing.T) {
 39	flag.Parse()
 40	var lock sync.Mutex
 41
 42	mkkey := func(name string) (string, *keygen.SSHKeyPair) {
 43		path := filepath.Join(t.TempDir(), name)
 44		pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
 45		if err != nil {
 46			t.Fatal(err)
 47		}
 48		return path, pair
 49	}
 50
 51	key, admin1 := mkkey("admin1")
 52	_, admin2 := mkkey("admin2")
 53	_, user1 := mkkey("user1")
 54
 55	testscript.Run(t, testscript.Params{
 56		Dir:           "./testdata/",
 57		UpdateScripts: *update,
 58		Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
 59			"soft":     cmdSoft(admin1.Signer()),
 60			"usoft":    cmdSoft(user1.Signer()),
 61			"git":      cmdGit(key),
 62			"curl":     cmdCurl,
 63			"mkfile":   cmdMkfile,
 64			"envfile":  cmdEnvfile,
 65			"readfile": cmdReadfile,
 66			"dos2unix": cmdDos2Unix,
 67		},
 68		Setup: func(e *testscript.Env) error {
 69			data := t.TempDir()
 70
 71			sshPort := test.RandomPort()
 72			sshListen := fmt.Sprintf("localhost:%d", sshPort)
 73			gitPort := test.RandomPort()
 74			gitListen := fmt.Sprintf("localhost:%d", gitPort)
 75			httpPort := test.RandomPort()
 76			httpListen := fmt.Sprintf("localhost:%d", httpPort)
 77			statsPort := test.RandomPort()
 78			statsListen := fmt.Sprintf("localhost:%d", statsPort)
 79			serverName := "Test Soft Serve"
 80
 81			e.Setenv("DATA_PATH", data)
 82			e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
 83			e.Setenv("HTTP_PORT", fmt.Sprintf("%d", httpPort))
 84			e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
 85			e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
 86			e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
 87			e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
 88			e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
 89
 90			cfg := config.DefaultConfig()
 91			cfg.DataPath = data
 92			cfg.Name = serverName
 93			cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
 94			cfg.SSH.ListenAddr = sshListen
 95			cfg.SSH.PublicURL = "ssh://" + sshListen
 96			cfg.Git.ListenAddr = gitListen
 97			cfg.HTTP.ListenAddr = httpListen
 98			cfg.HTTP.PublicURL = "http://" + httpListen
 99			cfg.Stats.ListenAddr = statsListen
100			cfg.DB.Driver = "sqlite"
101			cfg.LFS.Enabled = true
102			// TODO: run tests with both SSH enabled/disabled
103			cfg.LFS.SSHEnabled = false
104
105			if err := cfg.Validate(); err != nil {
106				return err
107			}
108
109			ctx := config.WithContext(context.Background(), cfg)
110
111			logger, f, err := logr.NewLogger(cfg)
112			if err != nil {
113				log.Errorf("failed to create logger: %v", err)
114			}
115
116			ctx = log.WithContext(ctx, logger)
117			if f != nil {
118				defer f.Close() // nolint: errcheck
119			}
120
121			// TODO: test postgres
122			dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
123			if err != nil {
124				return fmt.Errorf("open database: %w", err)
125			}
126
127			if err := migrate.Migrate(ctx, dbx); err != nil {
128				return fmt.Errorf("migrate database: %w", err)
129			}
130
131			ctx = db.WithContext(ctx, dbx)
132			datastore := database.New(ctx, dbx)
133			ctx = store.WithContext(ctx, datastore)
134			be := backend.New(ctx, cfg, dbx)
135			ctx = backend.WithContext(ctx, be)
136
137			// prevent race condition in lipgloss...
138			// this will probably be autofixed when we start using the colors
139			// from the ssh session instead of the server.
140			// XXX: take another look at this soon
141			lock.Lock()
142			srv, err := server.NewServer(ctx)
143			if err != nil {
144				return err
145			}
146			lock.Unlock()
147
148			go func() {
149				if err := srv.Start(); err != nil {
150					e.T().Fatal(err)
151				}
152			}()
153
154			e.Defer(func() {
155				defer dbx.Close() // nolint: errcheck
156				ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
157				defer cancel()
158				if err := srv.Shutdown(ctx); err != nil {
159					e.T().Fatal(err)
160				}
161			})
162
163			// wait until the server is up
164			for {
165				conn, _ := net.DialTimeout(
166					"tcp",
167					net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
168					time.Second,
169				)
170				if conn != nil {
171					conn.Close()
172					break
173				}
174			}
175
176			return nil
177		},
178	})
179}
180
181func cmdSoft(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
182	return func(ts *testscript.TestScript, neg bool, args []string) {
183		cli, err := ssh.Dial(
184			"tcp",
185			net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
186			&ssh.ClientConfig{
187				User:            "admin",
188				Auth:            []ssh.AuthMethod{ssh.PublicKeys(key)},
189				HostKeyCallback: ssh.InsecureIgnoreHostKey(),
190			},
191		)
192		ts.Check(err)
193		defer cli.Close()
194
195		sess, err := cli.NewSession()
196		ts.Check(err)
197		defer sess.Close()
198
199		sess.Stdout = ts.Stdout()
200		sess.Stderr = ts.Stderr()
201
202		check(ts, sess.Run(strings.Join(args, " ")), neg)
203	}
204}
205
206// P.S. Windows sucks!
207func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
208	if neg {
209		ts.Fatalf("unsupported: ! dos2unix")
210	}
211	if len(args) < 1 {
212		ts.Fatalf("usage: dos2unix paths...")
213	}
214	for _, arg := range args {
215		filename := ts.MkAbs(arg)
216		data, err := os.ReadFile(filename)
217		if err != nil {
218			ts.Fatalf("%s: %v", filename, err)
219		}
220
221		// Replace all '\r\n' with '\n'.
222		data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
223
224		if err := os.WriteFile(filename, data, 0o644); err != nil {
225			ts.Fatalf("%s: %v", filename, err)
226		}
227	}
228}
229
// sshConfig is the ssh_config template that cmdGit writes before each git
// invocation. Its single %q verb receives the known_hosts file path.
var (
	sshConfig = `
Host *
  UserKnownHostsFile %q
  StrictHostKeyChecking no
  IdentityAgent none
  IdentitiesOnly yes
  ServerAliveInterval 60
`
)
238
239func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
240	return func(ts *testscript.TestScript, neg bool, args []string) {
241		ts.Check(os.WriteFile(
242			ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
243			[]byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
244			0o600,
245		))
246		sshArgs := []string{
247			"-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
248			"-i", filepath.ToSlash(key),
249		}
250		ts.Setenv(
251			"GIT_SSH_COMMAND",
252			strings.Join(append([]string{"ssh"}, sshArgs...), " "),
253		)
254		// Disable git prompting for credentials.
255		ts.Setenv("GIT_TERMINAL_PROMPT", "0")
256		args = append([]string{
257			"-c", "user.email=john@example.com",
258			"-c", "user.name=John Doe",
259		}, args...)
260		check(ts, ts.Exec("git", args...), neg)
261	}
262}
263
264func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
265	if len(args) < 2 {
266		ts.Fatalf("usage: mkfile path content")
267	}
268	check(ts, os.WriteFile(
269		ts.MkAbs(args[0]),
270		[]byte(strings.Join(args[1:], " ")),
271		0o644,
272	), neg)
273}
274
275func check(ts *testscript.TestScript, err error, neg bool) {
276	if neg && err == nil {
277		ts.Fatalf("expected error, got nil")
278	}
279	if !neg {
280		ts.Check(err)
281	}
282}
283
284func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) {
285	ts.Stdout().Write([]byte(ts.ReadFile(args[0])))
286}
287
288func cmdEnvfile(ts *testscript.TestScript, neg bool, args []string) {
289	if len(args) < 1 {
290		ts.Fatalf("usage: envfile key=file...")
291	}
292
293	for _, arg := range args {
294		parts := strings.SplitN(arg, "=", 2)
295		if len(parts) != 2 {
296			ts.Fatalf("usage: envfile key=file...")
297		}
298		key := parts[0]
299		file := parts[1]
300		ts.Setenv(key, strings.TrimSpace(ts.ReadFile(file)))
301	}
302}
303
304func cmdCurl(ts *testscript.TestScript, neg bool, args []string) {
305	var verbose bool
306	var headers []string
307	var data string
308	method := http.MethodGet
309
310	cmd := &cobra.Command{
311		Use:  "curl",
312		Args: cobra.MinimumNArgs(1),
313		RunE: func(cmd *cobra.Command, args []string) error {
314			url, err := url.Parse(args[0])
315			if err != nil {
316				return err
317			}
318
319			req, err := http.NewRequest(method, url.String(), nil)
320			if err != nil {
321				return err
322			}
323
324			if data != "" {
325				req.Body = io.NopCloser(strings.NewReader(data))
326			}
327
328			if verbose {
329				fmt.Fprintf(cmd.ErrOrStderr(), "< %s %s\n", req.Method, url.String())
330			}
331
332			for _, header := range headers {
333				parts := strings.SplitN(header, ":", 2)
334				if len(parts) != 2 {
335					return fmt.Errorf("invalid header: %s", header)
336				}
337				req.Header.Add(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
338			}
339
340			if userInfo := url.User; userInfo != nil {
341				password, _ := userInfo.Password()
342				req.SetBasicAuth(userInfo.Username(), password)
343			}
344
345			if verbose {
346				for key, values := range req.Header {
347					for _, value := range values {
348						fmt.Fprintf(cmd.ErrOrStderr(), "< %s: %s\n", key, value)
349					}
350				}
351			}
352
353			resp, err := http.DefaultClient.Do(req)
354			if err != nil {
355				return err
356			}
357
358			if verbose {
359				fmt.Fprintf(ts.Stderr(), "> %s\n", resp.Status)
360				for key, values := range resp.Header {
361					for _, value := range values {
362						fmt.Fprintf(cmd.ErrOrStderr(), "> %s: %s\n", key, value)
363					}
364				}
365			}
366
367			defer resp.Body.Close()
368			buf, err := io.ReadAll(resp.Body)
369			if err != nil {
370				return err
371			}
372
373			cmd.Print(string(buf))
374
375			return nil
376		},
377	}
378
379	cmd.SetArgs(args)
380	cmd.SetOut(ts.Stdout())
381	cmd.SetErr(ts.Stderr())
382
383	cmd.Flags().BoolVarP(&verbose, "verbose", "v", verbose, "verbose")
384	cmd.Flags().StringArrayVarP(&headers, "header", "H", nil, "HTTP header")
385	cmd.Flags().StringVarP(&method, "request", "X", method, "HTTP method")
386	cmd.Flags().StringVarP(&data, "data", "d", data, "HTTP data")
387
388	check(ts, cmd.Execute(), neg)
389}