1package testscript
2
import (
	"bytes"
	"context"
	"flag"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/charmbracelet/keygen"
	"github.com/charmbracelet/log"
	"github.com/charmbracelet/soft-serve/server"
	"github.com/charmbracelet/soft-serve/server/backend"
	"github.com/charmbracelet/soft-serve/server/config"
	"github.com/charmbracelet/soft-serve/server/db"
	"github.com/charmbracelet/soft-serve/server/db/migrate"
	logr "github.com/charmbracelet/soft-serve/server/log"
	"github.com/charmbracelet/soft-serve/server/store"
	"github.com/charmbracelet/soft-serve/server/store/database"
	"github.com/charmbracelet/soft-serve/server/test"
	"github.com/rogpeppe/go-internal/testscript"
	"github.com/spf13/cobra"
	"golang.org/x/crypto/ssh"
	_ "modernc.org/sqlite" // sqlite Driver
)
35
// update is the -update command-line flag. Its value is handed to
// testscript as Params.UpdateScripts so that script files can be refreshed
// from actual command output instead of failing.
var update = flag.Bool(
	"update",
	false,
	"update script files",
)
37
38func TestScript(t *testing.T) {
39 flag.Parse()
40 var lock sync.Mutex
41
42 mkkey := func(name string) (string, *keygen.SSHKeyPair) {
43 path := filepath.Join(t.TempDir(), name)
44 pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
45 if err != nil {
46 t.Fatal(err)
47 }
48 return path, pair
49 }
50
51 key, admin1 := mkkey("admin1")
52 _, admin2 := mkkey("admin2")
53 _, user1 := mkkey("user1")
54
55 testscript.Run(t, testscript.Params{
56 Dir: "./testdata/",
57 UpdateScripts: *update,
58 Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
59 "soft": cmdSoft(admin1.Signer()),
60 "usoft": cmdSoft(user1.Signer()),
61 "git": cmdGit(key),
62 "curl": cmdCurl,
63 "mkfile": cmdMkfile,
64 "envfile": cmdEnvfile,
65 "readfile": cmdReadfile,
66 "dos2unix": cmdDos2Unix,
67 },
68 Setup: func(e *testscript.Env) error {
69 data := t.TempDir()
70
71 sshPort := test.RandomPort()
72 sshListen := fmt.Sprintf("localhost:%d", sshPort)
73 gitPort := test.RandomPort()
74 gitListen := fmt.Sprintf("localhost:%d", gitPort)
75 httpPort := test.RandomPort()
76 httpListen := fmt.Sprintf("localhost:%d", httpPort)
77 statsPort := test.RandomPort()
78 statsListen := fmt.Sprintf("localhost:%d", statsPort)
79 serverName := "Test Soft Serve"
80
81 e.Setenv("DATA_PATH", data)
82 e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
83 e.Setenv("HTTP_PORT", fmt.Sprintf("%d", httpPort))
84 e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
85 e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
86 e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
87 e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
88 e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
89
90 cfg := config.DefaultConfig()
91 cfg.DataPath = data
92 cfg.Name = serverName
93 cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
94 cfg.SSH.ListenAddr = sshListen
95 cfg.SSH.PublicURL = "ssh://" + sshListen
96 cfg.Git.ListenAddr = gitListen
97 cfg.HTTP.ListenAddr = httpListen
98 cfg.HTTP.PublicURL = "http://" + httpListen
99 cfg.Stats.ListenAddr = statsListen
100 cfg.DB.Driver = "sqlite"
101 cfg.LFS.Enabled = true
102 cfg.LFS.SSHEnabled = true
103
104 if err := cfg.Validate(); err != nil {
105 return err
106 }
107
108 ctx := config.WithContext(context.Background(), cfg)
109
110 logger, f, err := logr.NewLogger(cfg)
111 if err != nil {
112 log.Errorf("failed to create logger: %v", err)
113 }
114
115 ctx = log.WithContext(ctx, logger)
116 if f != nil {
117 defer f.Close() // nolint: errcheck
118 }
119
120 // TODO: test postgres
121 dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
122 if err != nil {
123 return fmt.Errorf("open database: %w", err)
124 }
125
126 if err := migrate.Migrate(ctx, dbx); err != nil {
127 return fmt.Errorf("migrate database: %w", err)
128 }
129
130 ctx = db.WithContext(ctx, dbx)
131 datastore := database.New(ctx, dbx)
132 ctx = store.WithContext(ctx, datastore)
133 be := backend.New(ctx, cfg, dbx)
134 ctx = backend.WithContext(ctx, be)
135
136 lock.Lock()
137 srv, err := server.NewServer(ctx)
138 if err != nil {
139 lock.Unlock()
140 return err
141 }
142 lock.Unlock()
143
144 go func() {
145 if err := srv.Start(); err != nil {
146 e.T().Fatal(err)
147 }
148 }()
149
150 e.Defer(func() {
151 defer dbx.Close() // nolint: errcheck
152 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
153 defer cancel()
154 lock.Lock()
155 defer lock.Unlock()
156 if err := srv.Shutdown(ctx); err != nil {
157 e.T().Fatal(err)
158 }
159 })
160
161 // wait until the server is up
162 for {
163 conn, _ := net.DialTimeout(
164 "tcp",
165 net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
166 time.Second,
167 )
168 if conn != nil {
169 conn.Close()
170 break
171 }
172 }
173
174 return nil
175 },
176 })
177}
178
179func cmdSoft(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
180 return func(ts *testscript.TestScript, neg bool, args []string) {
181 cli, err := ssh.Dial(
182 "tcp",
183 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
184 &ssh.ClientConfig{
185 User: "admin",
186 Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
187 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
188 },
189 )
190 ts.Check(err)
191 defer cli.Close()
192
193 sess, err := cli.NewSession()
194 ts.Check(err)
195 defer sess.Close()
196
197 sess.Stdout = ts.Stdout()
198 sess.Stderr = ts.Stderr()
199
200 check(ts, sess.Run(strings.Join(args, " ")), neg)
201 }
202}
203
204// P.S. Windows sucks!
205func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
206 if neg {
207 ts.Fatalf("unsupported: ! dos2unix")
208 }
209 if len(args) < 1 {
210 ts.Fatalf("usage: dos2unix paths...")
211 }
212 for _, arg := range args {
213 filename := ts.MkAbs(arg)
214 data, err := os.ReadFile(filename)
215 if err != nil {
216 ts.Fatalf("%s: %v", filename, err)
217 }
218
219 // Replace all '\r\n' with '\n'.
220 data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
221
222 if err := os.WriteFile(filename, data, 0o644); err != nil {
223 ts.Fatalf("%s: %v", filename, err)
224 }
225 }
226}
227
// sshConfig is the ssh_config(5) template that the git script command
// writes to $SSH_KNOWN_CONFIG_FILE before every git invocation and passes
// to ssh with -F. Its single %q verb is filled with $SSH_KNOWN_HOSTS_FILE.
// The options turn off strict host key checking and the SSH agent, and
// restrict authentication to explicitly supplied identities. Presumably
// this makes only the -i key given by the git command be offered.
var sshConfig = `
Host *
  UserKnownHostsFile %q
  StrictHostKeyChecking no
  IdentityAgent none
  IdentitiesOnly yes
  ServerAliveInterval 60
`
236
237func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
238 return func(ts *testscript.TestScript, neg bool, args []string) {
239 ts.Check(os.WriteFile(
240 ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
241 []byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
242 0o600,
243 ))
244 sshArgs := []string{
245 "-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
246 "-i", filepath.ToSlash(key),
247 }
248 ts.Setenv(
249 "GIT_SSH_COMMAND",
250 strings.Join(append([]string{"ssh"}, sshArgs...), " "),
251 )
252 // Disable git prompting for credentials.
253 ts.Setenv("GIT_TERMINAL_PROMPT", "0")
254 args = append([]string{
255 "-c", "user.email=john@example.com",
256 "-c", "user.name=John Doe",
257 }, args...)
258 check(ts, ts.Exec("git", args...), neg)
259 }
260}
261
262func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
263 if len(args) < 2 {
264 ts.Fatalf("usage: mkfile path content")
265 }
266 check(ts, os.WriteFile(
267 ts.MkAbs(args[0]),
268 []byte(strings.Join(args[1:], " ")),
269 0o644,
270 ), neg)
271}
272
273func check(ts *testscript.TestScript, err error, neg bool) {
274 if neg && err == nil {
275 ts.Fatalf("expected error, got nil")
276 }
277 if !neg {
278 ts.Check(err)
279 }
280}
281
282func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) {
283 ts.Stdout().Write([]byte(ts.ReadFile(args[0])))
284}
285
286func cmdEnvfile(ts *testscript.TestScript, neg bool, args []string) {
287 if len(args) < 1 {
288 ts.Fatalf("usage: envfile key=file...")
289 }
290
291 for _, arg := range args {
292 parts := strings.SplitN(arg, "=", 2)
293 if len(parts) != 2 {
294 ts.Fatalf("usage: envfile key=file...")
295 }
296 key := parts[0]
297 file := parts[1]
298 ts.Setenv(key, strings.TrimSpace(ts.ReadFile(file)))
299 }
300}
301
302func cmdCurl(ts *testscript.TestScript, neg bool, args []string) {
303 var verbose bool
304 var headers []string
305 var data string
306 method := http.MethodGet
307
308 cmd := &cobra.Command{
309 Use: "curl",
310 Args: cobra.MinimumNArgs(1),
311 RunE: func(cmd *cobra.Command, args []string) error {
312 url, err := url.Parse(args[0])
313 if err != nil {
314 return err
315 }
316
317 req, err := http.NewRequest(method, url.String(), nil)
318 if err != nil {
319 return err
320 }
321
322 if data != "" {
323 req.Body = io.NopCloser(strings.NewReader(data))
324 }
325
326 if verbose {
327 fmt.Fprintf(cmd.ErrOrStderr(), "< %s %s\n", req.Method, url.String())
328 }
329
330 for _, header := range headers {
331 parts := strings.SplitN(header, ":", 2)
332 if len(parts) != 2 {
333 return fmt.Errorf("invalid header: %s", header)
334 }
335 req.Header.Add(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
336 }
337
338 if userInfo := url.User; userInfo != nil {
339 password, _ := userInfo.Password()
340 req.SetBasicAuth(userInfo.Username(), password)
341 }
342
343 if verbose {
344 for key, values := range req.Header {
345 for _, value := range values {
346 fmt.Fprintf(cmd.ErrOrStderr(), "< %s: %s\n", key, value)
347 }
348 }
349 }
350
351 resp, err := http.DefaultClient.Do(req)
352 if err != nil {
353 return err
354 }
355
356 if verbose {
357 fmt.Fprintf(ts.Stderr(), "> %s\n", resp.Status)
358 for key, values := range resp.Header {
359 for _, value := range values {
360 fmt.Fprintf(cmd.ErrOrStderr(), "> %s: %s\n", key, value)
361 }
362 }
363 }
364
365 defer resp.Body.Close()
366 buf, err := io.ReadAll(resp.Body)
367 if err != nil {
368 return err
369 }
370
371 cmd.Print(string(buf))
372
373 return nil
374 },
375 }
376
377 cmd.SetArgs(args)
378 cmd.SetOut(ts.Stdout())
379 cmd.SetErr(ts.Stderr())
380
381 cmd.Flags().BoolVarP(&verbose, "verbose", "v", verbose, "verbose")
382 cmd.Flags().StringArrayVarP(&headers, "header", "H", nil, "HTTP header")
383 cmd.Flags().StringVarP(&method, "request", "X", method, "HTTP method")
384 cmd.Flags().StringVarP(&data, "data", "d", data, "HTTP data")
385
386 check(ts, cmd.Execute(), neg)
387}