1package testscript
2
3import (
4 "bytes"
5 "context"
6 "flag"
7 "fmt"
8 "net"
9 "os"
10 "path/filepath"
11 "strings"
12 "sync"
13 "testing"
14 "time"
15
16 "github.com/charmbracelet/keygen"
17 "github.com/charmbracelet/soft-serve/server"
18 "github.com/charmbracelet/soft-serve/server/config"
19 "github.com/charmbracelet/soft-serve/server/test"
20 "github.com/rogpeppe/go-internal/testscript"
21 "golang.org/x/crypto/ssh"
22)
23
var (
	// update is registered as the -update test flag and handed to testscript
	// as Params.UpdateScripts, which rewrites the scripts' expected output in
	// place instead of failing on a mismatch. Defaults to false.
	update = flag.Bool("update", false, "update script files")
)
25
26func TestScript(t *testing.T) {
27 flag.Parse()
28 var lock sync.Mutex
29
30 mkkey := func(name string) (string, *keygen.SSHKeyPair) {
31 path := filepath.Join(t.TempDir(), name)
32 pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
33 if err != nil {
34 t.Fatal(err)
35 }
36 return path, pair
37 }
38
39 key, admin1 := mkkey("admin1")
40 _, admin2 := mkkey("admin2")
41 _, user1 := mkkey("user1")
42
43 testscript.Run(t, testscript.Params{
44 Dir: "./testdata/",
45 UpdateScripts: *update,
46 Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
47 "soft": cmdSoft(admin1.Signer()),
48 "usoft": cmdSoft(user1.Signer()),
49 "git": cmdGit(key),
50 "mkfile": cmdMkfile,
51 "dos2unix": cmdDos2Unix,
52 },
53 Setup: func(e *testscript.Env) error {
54 sshPort := test.RandomPort()
55 e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
56 e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
57 e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
58 e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
59 e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
60 e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
61 data := t.TempDir()
62 cfg := config.Config{
63 Name: "Test Soft Serve",
64 DataPath: data,
65 InitialAdminKeys: []string{admin1.AuthorizedKey()},
66 SSH: config.SSHConfig{
67 ListenAddr: fmt.Sprintf("localhost:%d", sshPort),
68 PublicURL: fmt.Sprintf("ssh://localhost:%d", sshPort),
69 KeyPath: filepath.Join(data, "ssh", "soft_serve_host_ed25519"),
70 ClientKeyPath: filepath.Join(data, "ssh", "soft_serve_client_ed25519"),
71 },
72 Git: config.GitConfig{
73 ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
74 IdleTimeout: 3,
75 MaxConnections: 32,
76 },
77 HTTP: config.HTTPConfig{
78 ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
79 PublicURL: fmt.Sprintf("http://localhost:%d", test.RandomPort()),
80 },
81 Stats: config.StatsConfig{
82 ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
83 },
84 Log: config.LogConfig{
85 Format: "text",
86 TimeFormat: time.DateTime,
87 },
88 }
89 ctx := config.WithContext(context.Background(), &cfg)
90
91 // prevent race condition in lipgloss...
92 // this will probably be autofixed when we start using the colors
93 // from the ssh session instead of the server.
94 // XXX: take another look at this soon
95 lock.Lock()
96 srv, err := server.NewServer(ctx)
97 if err != nil {
98 return err
99 }
100 lock.Unlock()
101
102 go func() {
103 if err := srv.Start(); err != nil {
104 e.T().Fatal(err)
105 }
106 }()
107
108 e.Defer(func() {
109 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
110 defer cancel()
111 if err := srv.Shutdown(ctx); err != nil {
112 e.T().Fatal(err)
113 }
114 })
115
116 // wait until the server is up
117 for {
118 conn, _ := net.DialTimeout(
119 "tcp",
120 net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
121 time.Second,
122 )
123 if conn != nil {
124 conn.Close()
125 break
126 }
127 }
128
129 return nil
130 },
131 })
132}
133
134func cmdSoft(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
135 return func(ts *testscript.TestScript, neg bool, args []string) {
136 cli, err := ssh.Dial(
137 "tcp",
138 net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
139 &ssh.ClientConfig{
140 User: "admin",
141 Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
142 HostKeyCallback: ssh.InsecureIgnoreHostKey(),
143 },
144 )
145 ts.Check(err)
146 defer cli.Close()
147
148 sess, err := cli.NewSession()
149 ts.Check(err)
150 defer sess.Close()
151
152 sess.Stdout = ts.Stdout()
153 sess.Stderr = ts.Stderr()
154
155 check(ts, sess.Run(strings.Join(args, " ")), neg)
156 }
157}
158
159// P.S. Windows sucks!
160func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
161 if neg {
162 ts.Fatalf("unsupported: ! dos2unix")
163 }
164 if len(args) < 1 {
165 ts.Fatalf("usage: dos2unix paths...")
166 }
167 for _, arg := range args {
168 filename := ts.MkAbs(arg)
169 data, err := os.ReadFile(filename)
170 if err != nil {
171 ts.Fatalf("%s: %v", filename, err)
172 }
173
174 // Replace all '\r\n' with '\n'.
175 data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
176
177 if err := os.WriteFile(filename, data, 0o644); err != nil {
178 ts.Fatalf("%s: %v", filename, err)
179 }
180 }
181}
182
// sshConfig is the OpenSSH client configuration template that the git script
// command writes to $SSH_KNOWN_CONFIG_FILE before every invocation. Its single
// %q verb is filled with the per-script known_hosts path ($SSH_KNOWN_HOSTS_FILE).
// Unknown host keys are accepted without prompting, and no ssh-agent is
// consulted, so only the identity passed with -i is offered.
var sshConfig = `
Host *
  UserKnownHostsFile %q
  StrictHostKeyChecking no
  IdentityAgent none
  IdentitiesOnly yes
  ServerAliveInterval 60
`
191
192func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
193 return func(ts *testscript.TestScript, neg bool, args []string) {
194 ts.Check(os.WriteFile(
195 ts.Getenv("SSH_KNOWN_CONFIG_FILE"),
196 []byte(fmt.Sprintf(sshConfig, ts.Getenv("SSH_KNOWN_HOSTS_FILE"))),
197 0o600,
198 ))
199 sshArgs := []string{
200 "-F", filepath.ToSlash(ts.Getenv("SSH_KNOWN_CONFIG_FILE")),
201 "-i", filepath.ToSlash(key),
202 }
203 ts.Setenv(
204 "GIT_SSH_COMMAND",
205 strings.Join(append([]string{"ssh"}, sshArgs...), " "),
206 )
207 args = append([]string{
208 "-c", "user.email=john@example.com",
209 "-c", "user.name=John Doe",
210 }, args...)
211 check(ts, ts.Exec("git", args...), neg)
212 }
213}
214
215func cmdMkfile(ts *testscript.TestScript, neg bool, args []string) {
216 if len(args) < 2 {
217 ts.Fatalf("usage: mkfile path content")
218 }
219 check(ts, os.WriteFile(
220 ts.MkAbs(args[0]),
221 []byte(strings.Join(args[1:], " ")),
222 0o644,
223 ), neg)
224}
225
226func check(ts *testscript.TestScript, err error, neg bool) {
227 if neg && err == nil {
228 ts.Fatalf("expected error, got nil")
229 }
230 if !neg {
231 ts.Check(err)
232 }
233}