1package testscript
2
3import (
4 "bytes"
5 "context"
6 "flag"
7 "fmt"
8 "net"
9 "os"
10 "path/filepath"
11 "runtime"
12 "strings"
13 "sync"
14 "testing"
15 "time"
16
17 "github.com/charmbracelet/keygen"
18 "github.com/charmbracelet/soft-serve/server"
19 "github.com/charmbracelet/soft-serve/server/config"
20 "github.com/charmbracelet/soft-serve/server/test"
21 "github.com/rogpeppe/go-internal/testscript"
22 "golang.org/x/crypto/ssh"
23)
24
// update is set by the -update command-line flag and is handed to
// testscript as Params.UpdateScripts, letting testscript rewrite script
// files from the actual output instead of failing.
var (
	update = flag.Bool("update", false, "update script files")
)
26
27func TestScript(t *testing.T) {
28 flag.Parse()
29 var lock sync.Mutex
30
31 mkkey := func(name string) (string, *keygen.SSHKeyPair) {
32 path := filepath.Join(t.TempDir(), name)
33 pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
34 if err != nil {
35 t.Fatal(err)
36 }
37 return path, pair
38 }
39
40 key, admin1 := mkkey("admin1")
41 _, admin2 := mkkey("admin2")
42 _, user1 := mkkey("user1")
43
44 testscript.Run(t, testscript.Params{
45 Dir: "./testdata/",
46 UpdateScripts: *update,
47 Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
48 "soft": cmdSoft(admin1.Signer()),
49 "git": cmdGit(key),
50 "mkreadme": cmdMkReadme,
51 "dos2unix": cmdDos2Unix,
52 },
53 Setup: func(e *testscript.Env) error {
54 sshPort := test.RandomPort()
55 e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
56 e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
57 e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
58 e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
59 data := t.TempDir()
60 cfg := config.Config{
61 Name: "Test Soft Serve",
62 DataPath: data,
63 InitialAdminKeys: []string{admin1.AuthorizedKey()},
64 SSH: config.SSHConfig{
65 ListenAddr: fmt.Sprintf("localhost:%d", sshPort),
66 PublicURL: fmt.Sprintf("ssh://localhost:%d", sshPort),
67 KeyPath: filepath.Join(data, "ssh", "soft_serve_host_ed25519"),
68 ClientKeyPath: filepath.Join(data, "ssh", "soft_serve_client_ed25519"),
69 },
70 Git: config.GitConfig{
71 ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
72 IdleTimeout: 3,
73 MaxConnections: 32,
74 },
75 HTTP: config.HTTPConfig{
76 ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
77 PublicURL: fmt.Sprintf("http://localhost:%d", test.RandomPort()),
78 },
79 Stats: config.StatsConfig{
80 ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
81 },
82 Log: config.LogConfig{
83 Format: "text",
84 TimeFormat: time.DateTime,
85 },
86 }
87 ctx := config.WithContext(context.Background(), &cfg)
88
89 // prevent race condition in lipgloss...
90 // this will probably be autofixed when we start using the colors
91 // from the ssh session instead of the server.
92 // XXX: take another look at this soon
93 lock.Lock()
94 srv, err := server.NewServer(ctx)
95 if err != nil {
96 return err
97 }
98 lock.Unlock()
99
100 go func() {
101 if err := srv.Start(); err != nil {
102 e.T().Fatal(err)
103 }
104 }()
105
106 e.Defer(func() {
107 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
108 defer cancel()
109 if err := srv.Shutdown(ctx); err != nil {
110 e.T().Fatal(err)
111 }
112 })
113
114 // wait until the server is up
115 for {
116 conn, _ := net.DialTimeout(
117 "tcp",
118 net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
119 time.Second,
120 )
121 if conn != nil {
122 conn.Close()
123 break
124 }
125 }
126
127 return nil
128 },
129 })
130}
131
132func cmdSoft(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
133 return func(ts *testscript.TestScript, neg bool, args []string) {
134 cli, err := ssh.Dial(
135 "tcp",
136 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
137 &ssh.ClientConfig{
138 User: "admin",
139 Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
140 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
141 },
142 )
143 ts.Check(err)
144 defer cli.Close()
145
146 sess, err := cli.NewSession()
147 ts.Check(err)
148 defer sess.Close()
149
150 sess.Stdout = ts.Stdout()
151 sess.Stderr = ts.Stderr()
152
153 check(ts, sess.Run(strings.Join(args, " ")), neg)
154 }
155}
156
157// P.S. Windows sucks!
158func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
159 if neg {
160 ts.Fatalf("unsupported: ! dos2unix")
161 }
162 if len(args) < 1 {
163 ts.Fatalf("usage: dos2unix paths...")
164 }
165 for _, arg := range args {
166 filename := ts.MkAbs(arg)
167 data, err := os.ReadFile(filename)
168 if err != nil {
169 ts.Fatalf("%s: %v", filename, err)
170 }
171
172 // Replace all '\r\n' with '\n'.
173 data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
174
175 if err := os.WriteFile(filename, data, 0o644); err != nil {
176 ts.Fatalf("%s: %v", filename, err)
177 }
178 }
179}
180
181func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
182 return func(ts *testscript.TestScript, neg bool, args []string) {
183 sshArgs := []string{
184 "-o", "StrictHostKeyChecking=no",
185 "-o", "IdentityAgent=none",
186 "-o", "IdentitiesOnly=yes",
187 "-o", "ServerAliveInterval=60",
188 // Escape the key path for Windows.
189 "-i", strings.ReplaceAll(key, `\`, `\\`),
190 }
191 // Windows null device
192 // https://stackoverflow.com/a/36746090/10913628
193 if runtime.GOOS == "windows" {
194 sshArgs = append(sshArgs, []string{
195 "-F", `$nul`,
196 "-o", `UserKnownHostsFile=$nul`,
197 }...)
198 } else {
199 sshArgs = append(sshArgs, []string{
200 "-F", "/dev/null",
201 "-o", "UserKnownHostsFile=/dev/null",
202 }...)
203 }
204 ts.Setenv(
205 "GIT_SSH_COMMAND",
206 strings.Join(append([]string{"ssh"}, sshArgs...), " "),
207 )
208 args = append([]string{
209 "-c", "user.email=john@example.com",
210 "-c", "user.name=John Doe",
211 }, args...)
212 check(ts, ts.Exec("git", args...), neg)
213 }
214}
215
216func cmdMkReadme(ts *testscript.TestScript, neg bool, args []string) {
217 if len(args) != 1 {
218 ts.Fatalf("usage: mkreadme path")
219 }
220 content := []byte("# example\ntest project")
221 check(ts, os.WriteFile(ts.MkAbs(args[0]), content, 0o644), neg)
222}
223
224func check(ts *testscript.TestScript, err error, neg bool) {
225 if neg && err == nil {
226 ts.Fatalf("expected error, got nil")
227 }
228 if !neg {
229 ts.Check(err)
230 }
231}