1package testscript
2
3import (
4 "bytes"
5 "context"
6 "flag"
7 "fmt"
8 "net"
9 "os"
10 "path/filepath"
11 "strings"
12 "sync"
13 "testing"
14 "time"
15
16 "github.com/charmbracelet/keygen"
17 "github.com/charmbracelet/soft-serve/server"
18 "github.com/charmbracelet/soft-serve/server/backend"
19 "github.com/charmbracelet/soft-serve/server/config"
20 "github.com/charmbracelet/soft-serve/server/db"
21 "github.com/charmbracelet/soft-serve/server/db/migrate"
22 "github.com/charmbracelet/soft-serve/server/store"
23 "github.com/charmbracelet/soft-serve/server/store/database"
24 "github.com/charmbracelet/soft-serve/server/test"
25 "github.com/rogpeppe/go-internal/testscript"
26 "golang.org/x/crypto/ssh"
27 _ "modernc.org/sqlite" // sqlite Driver
28)
29
// update is the -update command-line flag. TestScript passes its value to
// testscript.Params.UpdateScripts.
var update = flag.Bool("update", false, "update script files")
31
32func TestScript(t *testing.T) {
33 flag.Parse()
34 var lock sync.Mutex
35
36 mkkey := func(name string) (string, *keygen.SSHKeyPair) {
37 path := filepath.Join(t.TempDir(), name)
38 pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
39 if err != nil {
40 t.Fatal(err)
41 }
42 return path, pair
43 }
44
45 key, admin1 := mkkey("admin1")
46 _, admin2 := mkkey("admin2")
47 _, user1 := mkkey("user1")
48
49 testscript.Run(t, testscript.Params{
50 Dir: "./testdata/",
51 UpdateScripts: *update,
52 Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
53 "soft": cmdSoft(admin1.Signer()),
54 "usoft": cmdSoft(user1.Signer()),
55 "git": cmdGit(key),
56 "mkfile": cmdMkfile,
57 "readfile": cmdReadfile,
58 "dos2unix": cmdDos2Unix,
59 },
60 Setup: func(e *testscript.Env) error {
61 data := t.TempDir()
62
63 sshPort := test.RandomPort()
64 sshListen := fmt.Sprintf("localhost:%d", sshPort)
65 gitPort := test.RandomPort()
66 gitListen := fmt.Sprintf("localhost:%d", gitPort)
67 httpPort := test.RandomPort()
68 httpListen := fmt.Sprintf("localhost:%d", httpPort)
69 statsPort := test.RandomPort()
70 statsListen := fmt.Sprintf("localhost:%d", statsPort)
71 serverName := "Test Soft Serve"
72
73 e.Setenv("DATA_PATH", data)
74 e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
75 e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
76 e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
77 e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
78 e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
79 e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
80
81 cfg := config.DefaultConfig()
82 cfg.DataPath = data
83 cfg.Name = serverName
84 cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
85 cfg.SSH.ListenAddr = sshListen
86 cfg.SSH.PublicURL = "ssh://" + sshListen
87 cfg.Git.ListenAddr = gitListen
88 cfg.HTTP.ListenAddr = httpListen
89 cfg.HTTP.PublicURL = "http://" + httpListen
90 cfg.Stats.ListenAddr = statsListen
91 cfg.DB.Driver = "sqlite"
92
93 if err := cfg.Validate(); err != nil {
94 return err
95 }
96
97 ctx := config.WithContext(context.Background(), cfg)
98
99 // TODO: test postgres
100 dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
101 if err != nil {
102 return fmt.Errorf("open database: %w", err)
103 }
104
105 if err := migrate.Migrate(ctx, dbx); err != nil {
106 return fmt.Errorf("migrate database: %w", err)
107 }
108
109 ctx = db.WithContext(ctx, dbx)
110 datastore := database.New(ctx, dbx)
111 ctx = store.WithContext(ctx, datastore)
112 be := backend.New(ctx, cfg, dbx)
113 ctx = backend.WithContext(ctx, be)
114
115 // prevent race condition in lipgloss...
116 // this will probably be autofixed when we start using the colors
117 // from the ssh session instead of the server.
118 // XXX: take another look at this soon
119 lock.Lock()
120 srv, err := server.NewServer(ctx)
121 if err != nil {
122 return err
123 }
124 lock.Unlock()
125
126 go func() {
127 if err := srv.Start(); err != nil {
128 e.T().Fatal(err)
129 }
130 }()
131
132 e.Defer(func() {
133 defer dbx.Close() // nolint: errcheck
134 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
135 defer cancel()
136 if err := srv.Shutdown(ctx); err != nil {
137 e.T().Fatal(err)
138 }
139 })
140
141 // wait until the server is up
142 for {
143 conn, _ := net.DialTimeout(
144 "tcp",
145 net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
146 time.Second,
147 )
148 if conn != nil {
149 conn.Close()
150 break
151 }
152 }
153
154 return nil
155 },
156 })
157}
158
159func cmdSoft(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
160 return func(ts *testscript.TestScript, neg bool, args []string) {
161 cli, err := ssh.Dial(
162 "tcp",
163 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
164 &ssh.ClientConfig{
165 User: "admin",
166 Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
167 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
168 },
169 )
170 ts.Check(err)
171 defer cli.Close()
172
173 sess, err := cli.NewSession()
174 ts.Check(err)
175 defer sess.Close()
176
177 sess.Stdout = ts.Stdout()
178 sess.Stderr = ts.Stderr()
179
180 check(ts, sess.Run(strings.Join(args, " ")), neg)
181 }
182}
183
184// P.S. Windows sucks!
185func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
186 if neg {
187 ts.Fatalf("unsupported: ! dos2unix")
188 }
189 if len(args) < 1 {
190 ts.Fatalf("usage: dos2unix paths...")
191 }
192 for _, arg := range args {
193 filename := ts.MkAbs(arg)
194 data, err := os.ReadFile(filename)
195 if err != nil {
196 ts.Fatalf("%s: %v", filename, err)
197 }
198
199 // Replace all '\r\n' with '\n'.
200 data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
201
202 if err := os.WriteFile(filename, data, 0o644); err != nil {
203 ts.Fatalf("%s: %v", filename, err)
204 }
205 }
206}
207
// sshConfig is the ssh client configuration template that cmdGit formats and
// writes to disk. Its single %q verb is filled with the value of
// SSH_KNOWN_HOSTS_FILE.
var sshConfig = `
Host *
 UserKnownHostsFile %q
 StrictHostKeyChecking no
 IdentityAgent none
 IdentitiesOnly yes
 ServerAliveInterval 60
`
216
217func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
218 return func(ts *testscript.TestScript, neg bool, args []string) {
219 ts.Check(os.WriteFile(
220 ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
221 []byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
222 0o600,
223 ))
224 sshArgs := []string{
225 "-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
226 "-i", filepath.ToSlash(key),
227 }
228 ts.Setenv(
229 "GIT_SSH_COMMAND",
230 strings.Join(append([]string{"ssh"}, sshArgs...), " "),
231 )
232 args = append([]string{
233 "-c", "user.email=john@example.com",
234 "-c", "user.name=John Doe",
235 }, args...)
236 check(ts, ts.Exec("git", args...), neg)
237 }
238}
239
240func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
241 if len(args) < 2 {
242 ts.Fatalf("usage: mkfile path content")
243 }
244 check(ts, os.WriteFile(
245 ts.MkAbs(args[0]),
246 []byte(strings.Join(args[1:], " ")),
247 0o644,
248 ), neg)
249}
250
251func check(ts *testscript.TestScript, err error, neg bool) {
252 if neg && err == nil {
253 ts.Fatalf("expected error, got nil")
254 }
255 if !neg {
256 ts.Check(err)
257 }
258}
259
260func cmdReadfile(ts *testscript.TestScript, neg bool, args []string) {
261 ts.Stdout().Write([]byte(ts.ReadFile(args[0])))
262}