1package testscript
2
3import (
4 "context"
5 "flag"
6 "fmt"
7 "net"
8 "os"
9 "os/exec"
10 "path/filepath"
11 "runtime"
12 "strings"
13 "sync"
14 "testing"
15 "time"
16
17 "github.com/charmbracelet/soft-serve/server"
18 "github.com/charmbracelet/soft-serve/server/config"
19 "github.com/charmbracelet/soft-serve/server/test"
20 "github.com/rogpeppe/go-internal/testscript"
21)
22
// update is the -update command-line flag. Its value is handed to testscript
// as Params.UpdateScripts, which lets failing cmp comparisons rewrite the
// expected content embedded in the script files instead of failing.
var (
	update = flag.Bool("update", false, "update script files")
)
24
25func TestScript(t *testing.T) {
26 flag.Parse()
27 var lock sync.Mutex
28
29 t.Setenv("SOFT_SERVE_TEST_NO_HOOKS", "1")
30
31 // we'll use this key to talk with soft serve, and since testscript changes
32 // the cwd, we need to get its full path here
33 key, err := filepath.Abs("./testdata/admin1")
34 if err != nil {
35 t.Fatal(err)
36 }
37
38 // git does not handle 0600, and on clone, will save the files with its
39 // default perm, 0644, which is too open for ssh.
40 for _, f := range []string{
41 "admin1",
42 "admin2",
43 "user1",
44 "user2",
45 } {
46 if err := os.Chmod(filepath.Join("./testdata/", f), 0o600); err != nil {
47 t.Fatal(err)
48 }
49 }
50
51 sshArgs := []string{
52 "-F", "/dev/null",
53 "-o", "StrictHostKeyChecking=no",
54 "-o", "UserKnownHostsFile=/dev/null",
55 "-o", "IdentityAgent=none",
56 "-o", "IdentitiesOnly=yes",
57 "-o", "ServerAliveInterval=60",
58 "-i", key,
59 }
60
61 check := func(ts *testscript.TestScript, err error, neg bool) {
62 if neg && err == nil {
63 ts.Fatalf("expected error, got nil")
64 }
65 if !neg {
66 ts.Check(err)
67 }
68 }
69
70 testscript.Run(t, testscript.Params{
71 Dir: "testdata/script",
72 UpdateScripts: *update,
73 Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
74 "soft": func(ts *testscript.TestScript, neg bool, args []string) {
75 args = append(
76 sshArgs,
77 append([]string{
78 "-p", ts.Getenv("SSH_PORT"),
79 "localhost",
80 "--",
81 }, args...)...,
82 )
83 if runtime.GOOS == "windows" {
84 cmd := exec.Command("ssh", args...)
85 out, err := cmd.CombinedOutput()
86 ts.Logf("RUNNING %v: output: %s error: %v", cmd.Args, string(out), err)
87 }
88 check(ts, ts.Exec("ssh", args...), neg)
89 },
90 "git": func(ts *testscript.TestScript, neg bool, args []string) {
91 ts.Setenv(
92 "GIT_SSH_COMMAND",
93 strings.Join(append([]string{"ssh"}, sshArgs...), " "),
94 )
95 args = append([]string{
96 "-c", "user.email=john@example.com",
97 "-c", "user.name=John Doe",
98 }, args...)
99 check(ts, ts.Exec("git", args...), neg)
100 },
101 "mkreadme": func(ts *testscript.TestScript, neg bool, args []string) {
102 if len(args) != 1 {
103 ts.Fatalf("must have exactly 1 arg, the filename, got %d", len(args))
104 }
105 check(ts, os.WriteFile(ts.MkAbs(args[0]), []byte("# example\ntest project"), 0o644), neg)
106 },
107 },
108 Setup: func(e *testscript.Env) error {
109 sshPort := test.RandomPort()
110 e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
111 data := t.TempDir()
112 cfg := config.Config{
113 Name: "Test Soft Serve",
114 DataPath: data,
115 InitialAdminKeys: []string{
116 "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJI/1tawpdPmzuJcTGTJ+QReqB6cRUdKj4iQIdJUFdrl",
117 },
118 SSH: config.SSHConfig{
119 ListenAddr: fmt.Sprintf("localhost:%d", sshPort),
120 PublicURL: fmt.Sprintf("ssh://localhost:%d", sshPort),
121 KeyPath: filepath.Join(data, "ssh", "soft_serve_host_ed25519"),
122 ClientKeyPath: filepath.Join(data, "ssh", "soft_serve_client_ed25519"),
123 },
124 Git: config.GitConfig{
125 ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
126 IdleTimeout: 3,
127 MaxConnections: 32,
128 },
129 HTTP: config.HTTPConfig{
130 ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
131 PublicURL: fmt.Sprintf("http://localhost:%d", test.RandomPort()),
132 },
133 Stats: config.StatsConfig{
134 ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
135 },
136 Log: config.LogConfig{
137 Format: "text",
138 TimeFormat: time.DateTime,
139 },
140 }
141 ctx := config.WithContext(context.Background(), &cfg)
142
143 // prevent race condition in lipgloss...
144 // this will probably be autofixed when we start using the colors
145 // from the ssh session instead of the server.
146 // XXX: take another look at this soon
147 lock.Lock()
148 srv, err := server.NewServer(ctx)
149 if err != nil {
150 return err
151 }
152 lock.Unlock()
153
154 go func() {
155 if err := srv.Start(); err != nil {
156 e.T().Fatal(err)
157 }
158 }()
159
160 e.Defer(func() {
161 ctx, cancel := context.WithTimeout(context.Background(), time.Second)
162 defer cancel()
163 if err := srv.Shutdown(ctx); err != nil {
164 e.T().Fatal(err)
165 }
166 })
167
168 // wait until the server is up
169 for {
170 conn, _ := net.DialTimeout(
171 "tcp",
172 net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
173 time.Second,
174 )
175 if conn != nil {
176 conn.Close()
177 break
178 }
179 }
180
181 return nil
182 },
183 })
184}