1package testscript
2
3import (
4 "bytes"
5 "context"
6 "flag"
7 "fmt"
8 "net"
9 "os"
10 "path/filepath"
11 "strings"
12 "sync"
13 "testing"
14 "time"
15
16 "github.com/charmbracelet/keygen"
17 "github.com/charmbracelet/soft-serve/server"
18 "github.com/charmbracelet/soft-serve/server/config"
19 "github.com/charmbracelet/soft-serve/server/test"
20 "github.com/rogpeppe/go-internal/testscript"
21 "golang.org/x/crypto/ssh"
22)
23
// update is the -update command-line flag (default false). TestScript passes
// its value to testscript as Params.UpdateScripts.
var update = flag.Bool(
	"update",
	false,
	"update script files",
)
25
26func TestScript(t *testing.T) {
27 flag.Parse()
28 var lock sync.Mutex
29
30 mkkey := func(name string) (string, *keygen.SSHKeyPair) {
31 path := filepath.Join(t.TempDir(), name)
32 pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
33 if err != nil {
34 t.Fatal(err)
35 }
36 return path, pair
37 }
38
39 key, admin1 := mkkey("admin1")
40 _, admin2 := mkkey("admin2")
41 _, user1 := mkkey("user1")
42
43 testscript.Run(t, testscript.Params{
44 Dir: "./testdata/",
45 UpdateScripts: *update,
46 Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
47 "soft": cmdSoft(admin1.Signer()),
48 "git": cmdGit(key),
49 "mkfile": cmdMkfile,
50 "dos2unix": cmdDos2Unix,
51 },
52 Setup: func(e *testscript.Env) error {
53 sshPort := test.RandomPort()
54 e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
55 e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
56 e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
57 e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
58 e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
59 e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
60 data := t.TempDir()
61 cfg := config.Config{
62 Name: "Test Soft Serve",
63 DataPath: data,
64 InitialAdminKeys: []string{admin1.AuthorizedKey()},
65 SSH: config.SSHConfig{
66 ListenAddr: fmt.Sprintf("localhost:%d", sshPort),
67 PublicURL: fmt.Sprintf("ssh://localhost:%d", sshPort),
68 KeyPath: filepath.Join(data, "ssh", "soft_serve_host_ed25519"),
69 ClientKeyPath: filepath.Join(data, "ssh", "soft_serve_client_ed25519"),
70 },
71 Git: config.GitConfig{
72 ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
73 IdleTimeout: 3,
74 MaxConnections: 32,
75 },
76 HTTP: config.HTTPConfig{
77 ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
78 PublicURL: fmt.Sprintf("http://localhost:%d", test.RandomPort()),
79 },
80 Stats: config.StatsConfig{
81 ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
82 },
83 Log: config.LogConfig{
84 Format: "text",
85 TimeFormat: time.DateTime,
86 },
87 }
88 ctx := config.WithContext(context.Background(), &cfg)
89
90 // prevent race condition in lipgloss...
91 // this will probably be autofixed when we start using the colors
92 // from the ssh session instead of the server.
93 // XXX: take another look at this soon
94 lock.Lock()
95 srv, err := server.NewServer(ctx)
96 if err != nil {
97 return err
98 }
99 lock.Unlock()
100
101 go func() {
102 if err := srv.Start(); err != nil {
103 e.T().Fatal(err)
104 }
105 }()
106
107 e.Defer(func() {
108 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
109 defer cancel()
110 if err := srv.Shutdown(ctx); err != nil {
111 e.T().Fatal(err)
112 }
113 })
114
115 // wait until the server is up
116 for {
117 conn, _ := net.DialTimeout(
118 "tcp",
119 net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
120 time.Second,
121 )
122 if conn != nil {
123 conn.Close()
124 break
125 }
126 }
127
128 return nil
129 },
130 })
131}
132
133func cmdSoft(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
134 return func(ts *testscript.TestScript, neg bool, args []string) {
135 cli, err := ssh.Dial(
136 "tcp",
137 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
138 &ssh.ClientConfig{
139 User: "admin",
140 Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
141 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
142 },
143 )
144 ts.Check(err)
145 defer cli.Close()
146
147 sess, err := cli.NewSession()
148 ts.Check(err)
149 defer sess.Close()
150
151 sess.Stdout = ts.Stdout()
152 sess.Stderr = ts.Stderr()
153
154 check(ts, sess.Run(strings.Join(args, " ")), neg)
155 }
156}
157
158// P.S. Windows sucks!
159func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
160 if neg {
161 ts.Fatalf("unsupported: ! dos2unix")
162 }
163 if len(args) < 1 {
164 ts.Fatalf("usage: dos2unix paths...")
165 }
166 for _, arg := range args {
167 filename := ts.MkAbs(arg)
168 data, err := os.ReadFile(filename)
169 if err != nil {
170 ts.Fatalf("%s: %v", filename, err)
171 }
172
173 // Replace all '\r\n' with '\n'.
174 data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
175
176 if err := os.WriteFile(filename, data, 0o644); err != nil {
177 ts.Fatalf("%s: %v", filename, err)
178 }
179 }
180}
181
// sshConfig is the ssh client configuration template that cmdGit writes to
// the file named by SSH_KNOWN_CONFIG_FILE. Its single %q verb is filled with
// the path of the known-hosts file (SSH_KNOWN_HOSTS_FILE).
var sshConfig = `
Host *
 UserKnownHostsFile %q
 StrictHostKeyChecking no
 IdentityAgent none
 IdentitiesOnly yes
 ServerAliveInterval 60
`
190
191func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
192 return func(ts *testscript.TestScript, neg bool, args []string) {
193 ts.Check(os.WriteFile(
194 ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
195 []byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
196 0o600,
197 ))
198 sshArgs := []string{
199 "-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
200 "-i", filepath.ToSlash(key),
201 }
202 ts.Setenv(
203 "GIT_SSH_COMMAND",
204 strings.Join(append([]string{"ssh"}, sshArgs...), " "),
205 )
206 args = append([]string{
207 "-c", "user.email=john@example.com",
208 "-c", "user.name=John Doe",
209 }, args...)
210 check(ts, ts.Exec("git", args...), neg)
211 }
212}
213
214func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
215 if len(args) < 2 {
216 ts.Fatalf("usage: mkfile path content")
217 }
218 check(ts, os.WriteFile(
219 ts.MkAbs(args[0]),
220 []byte(strings.Join(args[1:], " ")),
221 0o644,
222 ), neg)
223}
224
225func check(ts *testscript.TestScript, err error, neg bool) {
226 if neg && err == nil {
227 ts.Fatalf("expected error, got nil")
228 }
229 if !neg {
230 ts.Check(err)
231 }
232}