1package testscript
2
3import (
4 "bytes"
5 "context"
6 "flag"
7 "fmt"
8 "io"
9 "net"
10 "net/http"
11 "net/url"
12 "os"
13 "path/filepath"
14 "strings"
15 "sync"
16 "testing"
17 "time"
18
19 "github.com/charmbracelet/keygen"
20 "github.com/charmbracelet/log"
21 "github.com/charmbracelet/soft-serve/server"
22 "github.com/charmbracelet/soft-serve/server/backend"
23 "github.com/charmbracelet/soft-serve/server/config"
24 "github.com/charmbracelet/soft-serve/server/db"
25 "github.com/charmbracelet/soft-serve/server/db/migrate"
26 logr "github.com/charmbracelet/soft-serve/server/log"
27 "github.com/charmbracelet/soft-serve/server/store"
28 "github.com/charmbracelet/soft-serve/server/store/database"
29 "github.com/charmbracelet/soft-serve/server/test"
30 "github.com/rogpeppe/go-internal/testscript"
31 "github.com/spf13/cobra"
32 "golang.org/x/crypto/ssh"
33 _ "modernc.org/sqlite" // sqlite Driver
34)
35
// update is the -update test flag. Its value is passed to testscript as
// Params.UpdateScripts, so `go test -update` updates the script files in
// place instead of failing on mismatches.
var update = flag.Bool("update", false, "update script files")
37
38func TestScript(t *testing.T) {
39 flag.Parse()
40 var lock sync.Mutex
41
42 mkkey := func(name string) (string, *keygen.SSHKeyPair) {
43 path := filepath.Join(t.TempDir(), name)
44 pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
45 if err != nil {
46 t.Fatal(err)
47 }
48 return path, pair
49 }
50
51 key, admin1 := mkkey("admin1")
52 _, admin2 := mkkey("admin2")
53 _, user1 := mkkey("user1")
54
55 testscript.Run(t, testscript.Params{
56 Dir: "./testdata/",
57 UpdateScripts: *update,
58 Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
59 "soft": cmdSoft(admin1.Signer()),
60 "usoft": cmdSoft(user1.Signer()),
61 "git": cmdGit(key),
62 "curl": cmdCurl,
63 "mkfile": cmdMkfile,
64 "envfile": cmdEnvfile,
65 "readfile": cmdReadfile,
66 "dos2unix": cmdDos2Unix,
67 },
68 Setup: func(e *testscript.Env) error {
69 data := t.TempDir()
70
71 sshPort := test.RandomPort()
72 sshListen := fmt.Sprintf("localhost:%d", sshPort)
73 gitPort := test.RandomPort()
74 gitListen := fmt.Sprintf("localhost:%d", gitPort)
75 httpPort := test.RandomPort()
76 httpListen := fmt.Sprintf("localhost:%d", httpPort)
77 statsPort := test.RandomPort()
78 statsListen := fmt.Sprintf("localhost:%d", statsPort)
79 serverName := "Test Soft Serve"
80
81 e.Setenv("DATA_PATH", data)
82 e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
83 e.Setenv("HTTP_PORT", fmt.Sprintf("%d", httpPort))
84 e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
85 e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
86 e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
87 e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
88 e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
89
90 cfg := config.DefaultConfig()
91 cfg.DataPath = data
92 cfg.Name = serverName
93 cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
94 cfg.SSH.ListenAddr = sshListen
95 cfg.SSH.PublicURL = "ssh://" + sshListen
96 cfg.Git.ListenAddr = gitListen
97 cfg.HTTP.ListenAddr = httpListen
98 cfg.HTTP.PublicURL = "http://" + httpListen
99 cfg.Stats.ListenAddr = statsListen
100 cfg.DB.Driver = "sqlite"
101 cfg.LFS.Enabled = true
102 // TODO: run tests with both SSH enabled/disabled
103 cfg.LFS.SSHEnabled = false
104
105 if err := cfg.Validate(); err != nil {
106 return err
107 }
108
109 ctx := config.WithContext(context.Background(), cfg)
110
111 logger, f, err := logr.NewLogger(cfg)
112 if err != nil {
113 log.Errorf("failed to create logger: %v", err)
114 }
115
116 ctx = log.WithContext(ctx, logger)
117 if f != nil {
118 defer f.Close() // nolint: errcheck
119 }
120
121 // TODO: test postgres
122 dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
123 if err != nil {
124 return fmt.Errorf("open database: %w", err)
125 }
126
127 if err := migrate.Migrate(ctx, dbx); err != nil {
128 return fmt.Errorf("migrate database: %w", err)
129 }
130
131 ctx = db.WithContext(ctx, dbx)
132 datastore := database.New(ctx, dbx)
133 ctx = store.WithContext(ctx, datastore)
134 be := backend.New(ctx, cfg, dbx)
135 ctx = backend.WithContext(ctx, be)
136
137 // prevent race condition in lipgloss...
138 // this will probably be autofixed when we start using the colors
139 // from the ssh session instead of the server.
140 // XXX: take another look at this soon
141 lock.Lock()
142 srv, err := server.NewServer(ctx)
143 if err != nil {
144 return err
145 }
146 lock.Unlock()
147
148 go func() {
149 if err := srv.Start(); err != nil {
150 e.T().Fatal(err)
151 }
152 }()
153
154 e.Defer(func() {
155 defer dbx.Close() // nolint: errcheck
156 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
157 defer cancel()
158 if err := srv.Shutdown(ctx); err != nil {
159 e.T().Fatal(err)
160 }
161 })
162
163 // wait until the server is up
164 for {
165 conn, _ := net.DialTimeout(
166 "tcp",
167 net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
168 time.Second,
169 )
170 if conn != nil {
171 conn.Close()
172 break
173 }
174 }
175
176 return nil
177 },
178 })
179}
180
181func cmdSoft(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
182 return func(ts *testscript.TestScript, neg bool, args []string) {
183 cli, err := ssh.Dial(
184 "tcp",
185 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
186 &ssh.ClientConfig{
187 User: "admin",
188 Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
189 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
190 },
191 )
192 ts.Check(err)
193 defer cli.Close()
194
195 sess, err := cli.NewSession()
196 ts.Check(err)
197 defer sess.Close()
198
199 sess.Stdout = ts.Stdout()
200 sess.Stderr = ts.Stderr()
201
202 check(ts, sess.Run(strings.Join(args, " ")), neg)
203 }
204}
205
206// P.S. Windows sucks!
207func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
208 if neg {
209 ts.Fatalf("unsupported: ! dos2unix")
210 }
211 if len(args) < 1 {
212 ts.Fatalf("usage: dos2unix paths...")
213 }
214 for _, arg := range args {
215 filename := ts.MkAbs(arg)
216 data, err := os.ReadFile(filename)
217 if err != nil {
218 ts.Fatalf("%s: %v", filename, err)
219 }
220
221 // Replace all '\r\n' with '\n'.
222 data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
223
224 if err := os.WriteFile(filename, data, 0o644); err != nil {
225 ts.Fatalf("%s: %v", filename, err)
226 }
227 }
228}
229
// sshConfig is an ssh_config template that cmdGit writes to
// $SSH_KNOWN_CONFIG_FILE and passes to ssh via -F. The single %q verb is
// filled with the known_hosts path. For every host it disables strict host key
// checking and the SSH agent and restricts authentication to explicitly given
// identities (the -i key).
var sshConfig = `
Host *
  UserKnownHostsFile %q
  StrictHostKeyChecking no
  IdentityAgent none
  IdentitiesOnly yes
  ServerAliveInterval 60
`
238
239func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
240 return func(ts *testscript.TestScript, neg bool, args []string) {
241 ts.Check(os.WriteFile(
242 ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
243 []byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
244 0o600,
245 ))
246 sshArgs := []string{
247 "-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
248 "-i", filepath.ToSlash(key),
249 }
250 ts.Setenv(
251 "GIT_SSH_COMMAND",
252 strings.Join(append([]string{"ssh"}, sshArgs...), " "),
253 )
254 // Disable git prompting for credentials.
255 ts.Setenv("GIT_TERMINAL_PROMPT", "0")
256 args = append([]string{
257 "-c", "user.email=john@example.com",
258 "-c", "user.name=John Doe",
259 }, args...)
260 check(ts, ts.Exec("git", args...), neg)
261 }
262}
263
264func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
265 if len(args) < 2 {
266 ts.Fatalf("usage: mkfile path content")
267 }
268 check(ts, os.WriteFile(
269 ts.MkAbs(args[0]),
270 []byte(strings.Join(args[1:], " ")),
271 0o644,
272 ), neg)
273}
274
275func check(ts *testscript.TestScript, err error, neg bool) {
276 if neg && err == nil {
277 ts.Fatalf("expected error, got nil")
278 }
279 if !neg {
280 ts.Check(err)
281 }
282}
283
284func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) {
285 ts.Stdout().Write([]byte(ts.ReadFile(args[0])))
286}
287
288func cmdEnvfile(ts *testscript.TestScript, neg bool, args []string) {
289 if len(args) < 1 {
290 ts.Fatalf("usage: envfile key=file...")
291 }
292
293 for _, arg := range args {
294 parts := strings.SplitN(arg, "=", 2)
295 if len(parts) != 2 {
296 ts.Fatalf("usage: envfile key=file...")
297 }
298 key := parts[0]
299 file := parts[1]
300 ts.Setenv(key, strings.TrimSpace(ts.ReadFile(file)))
301 }
302}
303
304func cmdCurl(ts *testscript.TestScript, neg bool, args []string) {
305 var verbose bool
306 var headers []string
307 var data string
308 method := http.MethodGet
309
310 cmd := &cobra.Command{
311 Use: "curl",
312 Args: cobra.MinimumNArgs(1),
313 RunE: func(cmd *cobra.Command, args []string) error {
314 url, err := url.Parse(args[0])
315 if err != nil {
316 return err
317 }
318
319 req, err := http.NewRequest(method, url.String(), nil)
320 if err != nil {
321 return err
322 }
323
324 if data != "" {
325 req.Body = io.NopCloser(strings.NewReader(data))
326 }
327
328 if verbose {
329 fmt.Fprintf(cmd.ErrOrStderr(), "< %s %s\n", req.Method, url.String())
330 }
331
332 for _, header := range headers {
333 parts := strings.SplitN(header, ":", 2)
334 if len(parts) != 2 {
335 return fmt.Errorf("invalid header: %s", header)
336 }
337 req.Header.Add(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
338 }
339
340 if userInfo := url.User; userInfo != nil {
341 password, _ := userInfo.Password()
342 req.SetBasicAuth(userInfo.Username(), password)
343 }
344
345 if verbose {
346 for key, values := range req.Header {
347 for _, value := range values {
348 fmt.Fprintf(cmd.ErrOrStderr(), "< %s: %s\n", key, value)
349 }
350 }
351 }
352
353 resp, err := http.DefaultClient.Do(req)
354 if err != nil {
355 return err
356 }
357
358 if verbose {
359 fmt.Fprintf(ts.Stderr(), "> %s\n", resp.Status)
360 for key, values := range resp.Header {
361 for _, value := range values {
362 fmt.Fprintf(cmd.ErrOrStderr(), "> %s: %s\n", key, value)
363 }
364 }
365 }
366
367 defer resp.Body.Close()
368 buf, err := io.ReadAll(resp.Body)
369 if err != nil {
370 return err
371 }
372
373 cmd.Print(string(buf))
374
375 return nil
376 },
377 }
378
379 cmd.SetArgs(args)
380 cmd.SetOut(ts.Stdout())
381 cmd.SetErr(ts.Stderr())
382
383 cmd.Flags().BoolVarP(&verbose, "verbose", "v", verbose, "verbose")
384 cmd.Flags().StringArrayVarP(&headers, "header", "H", nil, "HTTP header")
385 cmd.Flags().StringVarP(&method, "request", "X", method, "HTTP method")
386 cmd.Flags().StringVarP(&data, "data", "d", data, "HTTP data")
387
388 check(ts, cmd.Execute(), neg)
389}