1package testscript
2
3import (
4 "bytes"
5 "context"
6 "flag"
7 "fmt"
8 "net"
9 "os"
10 "path/filepath"
11 "strings"
12 "sync"
13 "testing"
14 "time"
15
16 "github.com/charmbracelet/keygen"
17 "github.com/charmbracelet/soft-serve/server"
18 "github.com/charmbracelet/soft-serve/server/backend"
19 "github.com/charmbracelet/soft-serve/server/config"
20 "github.com/charmbracelet/soft-serve/server/db"
21 "github.com/charmbracelet/soft-serve/server/db/migrate"
22 "github.com/charmbracelet/soft-serve/server/test"
23 "github.com/rogpeppe/go-internal/testscript"
24 "golang.org/x/crypto/ssh"
25 _ "modernc.org/sqlite" // sqlite Driver
26)
27
// update is set by the -update flag and passed to testscript as
// Params.UpdateScripts, so script files are updated with actual output.
var update = flag.Bool("update", false, "update script files")
29
30func TestScript(t *testing.T) {
31 flag.Parse()
32 var lock sync.Mutex
33
34 mkkey := func(name string) (string, *keygen.SSHKeyPair) {
35 path := filepath.Join(t.TempDir(), name)
36 pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
37 if err != nil {
38 t.Fatal(err)
39 }
40 return path, pair
41 }
42
43 key, admin1 := mkkey("admin1")
44 _, admin2 := mkkey("admin2")
45 _, user1 := mkkey("user1")
46
47 testscript.Run(t, testscript.Params{
48 Dir: "./testdata/",
49 UpdateScripts: *update,
50 Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
51 "soft": cmdSoft(admin1.Signer()),
52 "usoft": cmdSoft(user1.Signer()),
53 "git": cmdGit(key),
54 "mkfile": cmdMkfile,
55 "readfile": cmdReadfile,
56 "dos2unix": cmdDos2Unix,
57 },
58 Setup: func(e *testscript.Env) error {
59 data := t.TempDir()
60
61 sshPort := test.RandomPort()
62 sshListen := fmt.Sprintf("localhost:%d", sshPort)
63 gitPort := test.RandomPort()
64 gitListen := fmt.Sprintf("localhost:%d", gitPort)
65 httpPort := test.RandomPort()
66 httpListen := fmt.Sprintf("localhost:%d", httpPort)
67 statsPort := test.RandomPort()
68 statsListen := fmt.Sprintf("localhost:%d", statsPort)
69 serverName := "Test Soft Serve"
70
71 e.Setenv("DATA_PATH", data)
72 e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
73 e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
74 e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
75 e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
76 e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
77 e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
78
79 cfg := config.DefaultConfig()
80 cfg.DataPath = data
81 cfg.Name = serverName
82 cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
83 cfg.SSH.ListenAddr = sshListen
84 cfg.SSH.PublicURL = "ssh://" + sshListen
85 cfg.Git.ListenAddr = gitListen
86 cfg.HTTP.ListenAddr = httpListen
87 cfg.HTTP.PublicURL = "http://" + httpListen
88 cfg.Stats.ListenAddr = statsListen
89 cfg.DB.Driver = "sqlite"
90
91 if err := cfg.Validate(); err != nil {
92 return err
93 }
94
95 ctx := config.WithContext(context.Background(), cfg)
96
97 // TODO: test postgres
98 dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
99 if err != nil {
100 return fmt.Errorf("open database: %w", err)
101 }
102
103 if err := migrate.Migrate(ctx, dbx); err != nil {
104 return fmt.Errorf("migrate database: %w", err)
105 }
106
107 ctx = db.WithContext(ctx, dbx)
108 be := backend.New(ctx, cfg, dbx)
109 ctx = backend.WithContext(ctx, be)
110
111 // prevent race condition in lipgloss...
112 // this will probably be autofixed when we start using the colors
113 // from the ssh session instead of the server.
114 // XXX: take another look at this soon
115 lock.Lock()
116 srv, err := server.NewServer(ctx)
117 if err != nil {
118 return err
119 }
120 lock.Unlock()
121
122 go func() {
123 if err := srv.Start(); err != nil {
124 e.T().Fatal(err)
125 }
126 }()
127
128 e.Defer(func() {
129 defer dbx.Close() // nolint: errcheck
130 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
131 defer cancel()
132 if err := srv.Shutdown(ctx); err != nil {
133 e.T().Fatal(err)
134 }
135 })
136
137 // wait until the server is up
138 for {
139 conn, _ := net.DialTimeout(
140 "tcp",
141 net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
142 time.Second,
143 )
144 if conn != nil {
145 conn.Close()
146 break
147 }
148 }
149
150 return nil
151 },
152 })
153}
154
155func cmdSoft(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
156 return func(ts *testscript.TestScript, neg bool, args []string) {
157 cli, err := ssh.Dial(
158 "tcp",
159 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
160 &ssh.ClientConfig{
161 User: "admin",
162 Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
163 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
164 },
165 )
166 ts.Check(err)
167 defer cli.Close()
168
169 sess, err := cli.NewSession()
170 ts.Check(err)
171 defer sess.Close()
172
173 sess.Stdout = ts.Stdout()
174 sess.Stderr = ts.Stderr()
175
176 check(ts, sess.Run(strings.Join(args, " ")), neg)
177 }
178}
179
180// P.S. Windows sucks!
181func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
182 if neg {
183 ts.Fatalf("unsupported: ! dos2unix")
184 }
185 if len(args) < 1 {
186 ts.Fatalf("usage: dos2unix paths...")
187 }
188 for _, arg := range args {
189 filename := ts.MkAbs(arg)
190 data, err := os.ReadFile(filename)
191 if err != nil {
192 ts.Fatalf("%s: %v", filename, err)
193 }
194
195 // Replace all '\r\n' with '\n'.
196 data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
197
198 if err := os.WriteFile(filename, data, 0o644); err != nil {
199 ts.Fatalf("%s: %v", filename, err)
200 }
201 }
202}
203
// sshConfig is the ssh_config template that cmdGit writes to
// $SSH_KNOWN_CONFIG_FILE before each git invocation. Its single %q verb is
// filled with the $SSH_KNOWN_HOSTS_FILE path. Host key checking is off, so
// unknown server keys are accepted without prompting. With no agent and
// IdentitiesOnly, only explicitly configured identities (cmdGit passes -i)
// are offered.
var sshConfig = `
Host *
  UserKnownHostsFile %q
  StrictHostKeyChecking no
  IdentityAgent none
  IdentitiesOnly yes
  ServerAliveInterval 60
`
212
213func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
214 return func(ts *testscript.TestScript, neg bool, args []string) {
215 ts.Check(os.WriteFile(
216 ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
217 []byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
218 0o600,
219 ))
220 sshArgs := []string{
221 "-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
222 "-i", filepath.ToSlash(key),
223 }
224 ts.Setenv(
225 "GIT_SSH_COMMAND",
226 strings.Join(append([]string{"ssh"}, sshArgs...), " "),
227 )
228 args = append([]string{
229 "-c", "user.email=john@example.com",
230 "-c", "user.name=John Doe",
231 }, args...)
232 check(ts, ts.Exec("git", args...), neg)
233 }
234}
235
236func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
237 if len(args) < 2 {
238 ts.Fatalf("usage: mkfile path content")
239 }
240 check(ts, os.WriteFile(
241 ts.MkAbs(args[0]),
242 []byte(strings.Join(args[1:], " ")),
243 0o644,
244 ), neg)
245}
246
247func check(ts *testscript.TestScript, err error, neg bool) {
248 if neg && err == nil {
249 ts.Fatalf("expected error, got nil")
250 }
251 if !neg {
252 ts.Check(err)
253 }
254}
255
256func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) {
257 ts.Stdout().Write([]byte(ts.ReadFile(args[0])))
258}