1package testscript
2
3import (
4 "bytes"
5 "context"
6 "flag"
7 "fmt"
8 "net"
9 "os"
10 "path/filepath"
11 "strings"
12 "sync"
13 "testing"
14 "time"
15
16 "github.com/charmbracelet/keygen"
17 "github.com/charmbracelet/log"
18 "github.com/charmbracelet/soft-serve/server"
19 "github.com/charmbracelet/soft-serve/server/backend"
20 "github.com/charmbracelet/soft-serve/server/config"
21 "github.com/charmbracelet/soft-serve/server/db"
22 "github.com/charmbracelet/soft-serve/server/db/migrate"
23 logr "github.com/charmbracelet/soft-serve/server/log"
24 "github.com/charmbracelet/soft-serve/server/store"
25 "github.com/charmbracelet/soft-serve/server/store/database"
26 "github.com/charmbracelet/soft-serve/server/test"
27 "github.com/rogpeppe/go-internal/testscript"
28 "golang.org/x/crypto/ssh"
29 _ "modernc.org/sqlite" // sqlite Driver
30)
31
// update is set with the -update test flag. Its value is handed to
// testscript as Params.UpdateScripts, asking testscript to update the
// script files instead of only reporting mismatches.
var update = flag.Bool("update", false, "update script files")
33
34func TestScript(t *testing.T) {
35 flag.Parse()
36 var lock sync.Mutex
37
38 mkkey := func(name string) (string, *keygen.SSHKeyPair) {
39 path := filepath.Join(t.TempDir(), name)
40 pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
41 if err != nil {
42 t.Fatal(err)
43 }
44 return path, pair
45 }
46
47 key, admin1 := mkkey("admin1")
48 _, admin2 := mkkey("admin2")
49 _, user1 := mkkey("user1")
50
51 testscript.Run(t, testscript.Params{
52 Dir: "./testdata/",
53 UpdateScripts: *update,
54 Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
55 "soft": cmdSoft(admin1.Signer()),
56 "usoft": cmdSoft(user1.Signer()),
57 "git": cmdGit(key),
58 "mkfile": cmdMkfile,
59 "envfile": cmdEnvfile,
60 "readfile": cmdReadfile,
61 "dos2unix": cmdDos2Unix,
62 },
63 Setup: func(e *testscript.Env) error {
64 data := t.TempDir()
65
66 sshPort := test.RandomPort()
67 sshListen := fmt.Sprintf("localhost:%d", sshPort)
68 gitPort := test.RandomPort()
69 gitListen := fmt.Sprintf("localhost:%d", gitPort)
70 httpPort := test.RandomPort()
71 httpListen := fmt.Sprintf("localhost:%d", httpPort)
72 statsPort := test.RandomPort()
73 statsListen := fmt.Sprintf("localhost:%d", statsPort)
74 serverName := "Test Soft Serve"
75
76 e.Setenv("DATA_PATH", data)
77 e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
78 e.Setenv("HTTP_PORT", fmt.Sprintf("%d", httpPort))
79 e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
80 e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
81 e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
82 e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
83 e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
84
85 cfg := config.DefaultConfig()
86 cfg.DataPath = data
87 cfg.Name = serverName
88 cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
89 cfg.SSH.ListenAddr = sshListen
90 cfg.SSH.PublicURL = "ssh://" + sshListen
91 cfg.Git.ListenAddr = gitListen
92 cfg.HTTP.ListenAddr = httpListen
93 cfg.HTTP.PublicURL = "http://" + httpListen
94 cfg.Stats.ListenAddr = statsListen
95 cfg.DB.Driver = "sqlite"
96 cfg.LFS.Enabled = true
97 // TODO: run tests with both SSH enabled/disabled
98 cfg.LFS.SSHEnabled = false
99
100 if err := cfg.Validate(); err != nil {
101 return err
102 }
103
104 ctx := config.WithContext(context.Background(), cfg)
105
106 logger, f, err := logr.NewLogger(cfg)
107 if err != nil {
108 log.Errorf("failed to create logger: %v", err)
109 }
110
111 ctx = log.WithContext(ctx, logger)
112 if f != nil {
113 defer f.Close() // nolint: errcheck
114 }
115
116 // TODO: test postgres
117 dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
118 if err != nil {
119 return fmt.Errorf("open database: %w", err)
120 }
121
122 if err := migrate.Migrate(ctx, dbx); err != nil {
123 return fmt.Errorf("migrate database: %w", err)
124 }
125
126 ctx = db.WithContext(ctx, dbx)
127 datastore := database.New(ctx, dbx)
128 ctx = store.WithContext(ctx, datastore)
129 be := backend.New(ctx, cfg, dbx)
130 ctx = backend.WithContext(ctx, be)
131
132 // prevent race condition in lipgloss...
133 // this will probably be autofixed when we start using the colors
134 // from the ssh session instead of the server.
135 // XXX: take another look at this soon
136 lock.Lock()
137 srv, err := server.NewServer(ctx)
138 if err != nil {
139 return err
140 }
141 lock.Unlock()
142
143 go func() {
144 if err := srv.Start(); err != nil {
145 e.T().Fatal(err)
146 }
147 }()
148
149 e.Defer(func() {
150 defer dbx.Close() // nolint: errcheck
151 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
152 defer cancel()
153 if err := srv.Shutdown(ctx); err != nil {
154 e.T().Fatal(err)
155 }
156 })
157
158 // wait until the server is up
159 for {
160 conn, _ := net.DialTimeout(
161 "tcp",
162 net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
163 time.Second,
164 )
165 if conn != nil {
166 conn.Close()
167 break
168 }
169 }
170
171 return nil
172 },
173 })
174}
175
176func cmdSoft(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
177 return func(ts *testscript.TestScript, neg bool, args []string) {
178 cli, err := ssh.Dial(
179 "tcp",
180 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
181 &ssh.ClientConfig{
182 User: "admin",
183 Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
184 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
185 },
186 )
187 ts.Check(err)
188 defer cli.Close()
189
190 sess, err := cli.NewSession()
191 ts.Check(err)
192 defer sess.Close()
193
194 sess.Stdout = ts.Stdout()
195 sess.Stderr = ts.Stderr()
196
197 check(ts, sess.Run(strings.Join(args, " ")), neg)
198 }
199}
200
201// P.S. Windows sucks!
202func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
203 if neg {
204 ts.Fatalf("unsupported: ! dos2unix")
205 }
206 if len(args) < 1 {
207 ts.Fatalf("usage: dos2unix paths...")
208 }
209 for _, arg := range args {
210 filename := ts.MkAbs(arg)
211 data, err := os.ReadFile(filename)
212 if err != nil {
213 ts.Fatalf("%s: %v", filename, err)
214 }
215
216 // Replace all '\r\n' with '\n'.
217 data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
218
219 if err := os.WriteFile(filename, data, 0o644); err != nil {
220 ts.Fatalf("%s: %v", filename, err)
221 }
222 }
223}
224
// sshConfig is the OpenSSH client configuration that the "git" script
// command writes to $SSH_KNOWN_CONFIG_FILE before every invocation. The %q
// verb is filled with the per-script $SSH_KNOWN_HOSTS_FILE path.
//
// Host key checking is relaxed and the SSH agent is disabled
// (IdentityAgent none, IdentitiesOnly yes), so that only the key passed to
// ssh with -i is used against the test server.
var sshConfig = `
Host *
  UserKnownHostsFile %q
  StrictHostKeyChecking no
  IdentityAgent none
  IdentitiesOnly yes
  ServerAliveInterval 60
`
233
234func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
235 return func(ts *testscript.TestScript, neg bool, args []string) {
236 ts.Check(os.WriteFile(
237 ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
238 []byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
239 0o600,
240 ))
241 sshArgs := []string{
242 "-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
243 "-i", filepath.ToSlash(key),
244 }
245 ts.Setenv(
246 "GIT_SSH_COMMAND",
247 strings.Join(append([]string{"ssh"}, sshArgs...), " "),
248 )
249 // Disable git prompting for credentials.
250 ts.Setenv("GIT_TERMINAL_PROMPT", "0")
251 args = append([]string{
252 "-c", "user.email=john@example.com",
253 "-c", "user.name=John Doe",
254 }, args...)
255 check(ts, ts.Exec("git", args...), neg)
256 }
257}
258
259func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
260 if len(args) < 2 {
261 ts.Fatalf("usage: mkfile path content")
262 }
263 check(ts, os.WriteFile(
264 ts.MkAbs(args[0]),
265 []byte(strings.Join(args[1:], " ")),
266 0o644,
267 ), neg)
268}
269
270func check(ts *testscript.TestScript, err error, neg bool) {
271 if neg && err == nil {
272 ts.Fatalf("expected error, got nil")
273 }
274 if !neg {
275 ts.Check(err)
276 }
277}
278
279func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) {
280 ts.Stdout().Write([]byte(ts.ReadFile(args[0])))
281}
282
283func cmdEnvfile(ts *testscript.TestScript, neg bool, args []string) {
284 if len(args) < 1 {
285 ts.Fatalf("usage: envfile key=file...")
286 }
287
288 for _, arg := range args {
289 parts := strings.SplitN(arg, "=", 2)
290 if len(parts) != 2 {
291 ts.Fatalf("usage: envfile key=file...")
292 }
293 key := parts[0]
294 file := parts[1]
295 ts.Setenv(key, strings.TrimSpace(ts.ReadFile(file)))
296 }
297}