1package testscript
2
3import (
4 "bytes"
5 "context"
6 "flag"
7 "fmt"
8 "net"
9 "os"
10 "path/filepath"
11 "strings"
12 "sync"
13 "testing"
14 "time"
15
16 "github.com/charmbracelet/keygen"
17 "github.com/charmbracelet/soft-serve/server"
18 "github.com/charmbracelet/soft-serve/server/backend"
19 "github.com/charmbracelet/soft-serve/server/config"
20 "github.com/charmbracelet/soft-serve/server/db"
21 "github.com/charmbracelet/soft-serve/server/db/migrate"
22 "github.com/charmbracelet/soft-serve/server/test"
23 "github.com/rogpeppe/go-internal/testscript"
24 "golang.org/x/crypto/ssh"
25 _ "modernc.org/sqlite" // sqlite Driver
26)
27
// update is the -update command-line flag (default false). TestScript passes
// its value to testscript.Params.UpdateScripts.
var update = flag.Bool(
	"update",
	false,
	"update script files",
)
29
30func TestScript(t *testing.T) {
31 flag.Parse()
32 var lock sync.Mutex
33
34 mkkey := func(name string) (string, *keygen.SSHKeyPair) {
35 path := filepath.Join(t.TempDir(), name)
36 pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
37 if err != nil {
38 t.Fatal(err)
39 }
40 return path, pair
41 }
42
43 key, admin1 := mkkey("admin1")
44 _, admin2 := mkkey("admin2")
45 _, user1 := mkkey("user1")
46
47 testscript.Run(t, testscript.Params{
48 Dir: "./testdata/",
49 UpdateScripts: *update,
50 Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
51 "soft": cmdSoft(admin1.Signer()),
52 "usoft": cmdSoft(user1.Signer()),
53 "git": cmdGit(key),
54 "mkfile": cmdMkfile,
55 "dos2unix": cmdDos2Unix,
56 },
57 Setup: func(e *testscript.Env) error {
58 data := t.TempDir()
59
60 sshPort := test.RandomPort()
61 sshListen := fmt.Sprintf("localhost:%d", sshPort)
62 gitPort := test.RandomPort()
63 gitListen := fmt.Sprintf("localhost:%d", gitPort)
64 httpPort := test.RandomPort()
65 httpListen := fmt.Sprintf("localhost:%d", httpPort)
66 statsPort := test.RandomPort()
67 statsListen := fmt.Sprintf("localhost:%d", statsPort)
68 serverName := "Test Soft Serve"
69
70 e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
71 e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
72 e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
73 e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
74 e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
75 e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
76
77 cfg := config.DefaultConfig()
78 cfg.DataPath = data
79 cfg.Name = serverName
80 cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
81 cfg.SSH.ListenAddr = sshListen
82 cfg.SSH.PublicURL = "ssh://" + sshListen
83 cfg.Git.ListenAddr = gitListen
84 cfg.HTTP.ListenAddr = httpListen
85 cfg.HTTP.PublicURL = "http://" + httpListen
86 cfg.Stats.ListenAddr = statsListen
87 cfg.DB.Driver = "sqlite"
88
89 if err := cfg.Validate(); err != nil {
90 return err
91 }
92
93 ctx := config.WithContext(context.Background(), cfg)
94
95 // TODO: test postgres
96 dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
97 if err != nil {
98 return fmt.Errorf("open database: %w", err)
99 }
100
101 if err := migrate.Migrate(ctx, dbx); err != nil {
102 return fmt.Errorf("migrate database: %w", err)
103 }
104
105 ctx = db.WithContext(ctx, dbx)
106 be := backend.New(ctx, cfg, dbx)
107 ctx = backend.WithContext(ctx, be)
108
109 // prevent race condition in lipgloss...
110 // this will probably be autofixed when we start using the colors
111 // from the ssh session instead of the server.
112 // XXX: take another look at this soon
113 lock.Lock()
114 srv, err := server.NewServer(ctx)
115 if err != nil {
116 return err
117 }
118 lock.Unlock()
119
120 go func() {
121 if err := srv.Start(); err != nil {
122 e.T().Fatal(err)
123 }
124 }()
125
126 e.Defer(func() {
127 defer dbx.Close() // nolint: errcheck
128 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
129 defer cancel()
130 if err := srv.Shutdown(ctx); err != nil {
131 e.T().Fatal(err)
132 }
133 })
134
135 // wait until the server is up
136 for {
137 conn, _ := net.DialTimeout(
138 "tcp",
139 net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
140 time.Second,
141 )
142 if conn != nil {
143 conn.Close()
144 break
145 }
146 }
147
148 return nil
149 },
150 })
151}
152
153func cmdSoft(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
154 return func(ts *testscript.TestScript, neg bool, args []string) {
155 cli, err := ssh.Dial(
156 "tcp",
157 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
158 &ssh.ClientConfig{
159 User: "admin",
160 Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
161 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
162 },
163 )
164 ts.Check(err)
165 defer cli.Close()
166
167 sess, err := cli.NewSession()
168 ts.Check(err)
169 defer sess.Close()
170
171 sess.Stdout = ts.Stdout()
172 sess.Stderr = ts.Stderr()
173
174 check(ts, sess.Run(strings.Join(args, " ")), neg)
175 }
176}
177
178// P.S. Windows sucks!
179func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
180 if neg {
181 ts.Fatalf("unsupported: ! dos2unix")
182 }
183 if len(args) < 1 {
184 ts.Fatalf("usage: dos2unix paths...")
185 }
186 for _, arg := range args {
187 filename := ts.MkAbs(arg)
188 data, err := os.ReadFile(filename)
189 if err != nil {
190 ts.Fatalf("%s: %v", filename, err)
191 }
192
193 // Replace all '\r\n' with '\n'.
194 data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
195
196 if err := os.WriteFile(filename, data, 0o644); err != nil {
197 ts.Fatalf("%s: %v", filename, err)
198 }
199 }
200}
201
// sshConfig is the template for the ssh client configuration file that cmdGit
// writes before invoking git. Its single %q verb is filled in, via
// fmt.Sprintf, with the value of the SSH_KNOWN_HOSTS_FILE script variable.
var sshConfig = `
Host *
 UserKnownHostsFile %q
 StrictHostKeyChecking no
 IdentityAgent none
 IdentitiesOnly yes
 ServerAliveInterval 60
`
210
211func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
212 return func(ts *testscript.TestScript, neg bool, args []string) {
213 ts.Check(os.WriteFile(
214 ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
215 []byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
216 0o600,
217 ))
218 sshArgs := []string{
219 "-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
220 "-i", filepath.ToSlash(key),
221 }
222 ts.Setenv(
223 "GIT_SSH_COMMAND",
224 strings.Join(append([]string{"ssh"}, sshArgs...), " "),
225 )
226 args = append([]string{
227 "-c", "user.email=john@example.com",
228 "-c", "user.name=John Doe",
229 }, args...)
230 check(ts, ts.Exec("git", args...), neg)
231 }
232}
233
234func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
235 if len(args) < 2 {
236 ts.Fatalf("usage: mkfile path content")
237 }
238 check(ts, os.WriteFile(
239 ts.MkAbs(args[0]),
240 []byte(strings.Join(args[1:], " ")),
241 0o644,
242 ), neg)
243}
244
245func check(ts *testscript.TestScript, err error, neg bool) {
246 if neg && err == nil {
247 ts.Fatalf("expected error, got nil")
248 }
249 if !neg {
250 ts.Check(err)
251 }
252}