1package testscript
2
3import (
4 "context"
5 "flag"
6 "fmt"
7 "net"
8 "os"
9 "path/filepath"
10 "strings"
11 "sync"
12 "testing"
13 "time"
14
15 "github.com/charmbracelet/soft-serve/server"
16 "github.com/charmbracelet/soft-serve/server/config"
17 "github.com/charmbracelet/soft-serve/server/test"
18 "github.com/rogpeppe/go-internal/testscript"
19)
20
21var update = flag.Bool("update", false, "update script files")
22
23func TestScript(t *testing.T) {
24 flag.Parse()
25 var lock sync.Mutex
26
27 t.Setenv("SOFT_SERVE_TEST_NO_HOOKS", "1")
28
29 // we'll use this key to talk with soft serve, and since testscript changes
30 // the cwd, we need to get its full path here
31 key, err := filepath.Abs("./testdata/admin1")
32 if err != nil {
33 t.Fatal(err)
34 }
35
36 // git does not handle 0600, and on clone, will save the files with its
37 // default perm, 0644, which is too open for ssh.
38 for _, f := range []string{
39 "admin1",
40 "admin2",
41 "user1",
42 "user2",
43 } {
44 if err := os.Chmod(filepath.Join("./testdata/", f), 0o600); err != nil {
45 t.Fatal(err)
46 }
47 }
48
49 sshArgs := []string{
50 "-F", "/dev/null",
51 "-o", "StrictHostKeyChecking=no",
52 "-o", "UserKnownHostsFile=/dev/null",
53 "-o", "IdentityAgent=none",
54 "-o", "IdentitiesOnly=yes",
55 "-i", key,
56 }
57
58 check := func(ts *testscript.TestScript, err error, neg bool) {
59 if neg && err == nil {
60 ts.Fatalf("expected error, got nil")
61 }
62 if !neg {
63 ts.Check(err)
64 }
65 }
66
67 testscript.Run(t, testscript.Params{
68 Dir: "testdata/script",
69 UpdateScripts: *update,
70 Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
71 "soft": func(ts *testscript.TestScript, neg bool, args []string) {
72 args = append(
73 sshArgs,
74 append([]string{
75 "-p", ts.Getenv("SSH_PORT"),
76 "localhost",
77 "--",
78 }, args...)...,
79 )
80 check(ts, ts.Exec("ssh", args...), neg)
81 },
82 "git": func(ts *testscript.TestScript, neg bool, args []string) {
83 ts.Setenv(
84 "GIT_SSH_COMMAND",
85 strings.Join(append([]string{"ssh"}, sshArgs...), " "),
86 )
87 args = append([]string{
88 "-c", "user.email=john@example.com",
89 "-c", "user.name=John Doe",
90 }, args...)
91 check(ts, ts.Exec("git", args...), neg)
92 },
93 "mkreadme": func(ts *testscript.TestScript, neg bool, args []string) {
94 if len(args) != 1 {
95 ts.Fatalf("must have exactly 1 arg, the filename, got %d", len(args))
96 }
97 check(ts, os.WriteFile(ts.MkAbs(args[0]), []byte("# example\ntest project"), 0o644), neg)
98 },
99 },
100 Setup: func(e *testscript.Env) error {
101 sshPort := test.RandomPort()
102 e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
103 data := t.TempDir()
104 cfg := config.Config{
105 Name: "Test Soft Serve",
106 DataPath: data,
107 InitialAdminKeys: []string{
108 "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJI/1tawpdPmzuJcTGTJ+QReqB6cRUdKj4iQIdJUFdrl",
109 },
110 SSH: config.SSHConfig{
111 ListenAddr: fmt.Sprintf("localhost:%d", sshPort),
112 PublicURL: fmt.Sprintf("ssh://localhost:%d", sshPort),
113 KeyPath: filepath.Join(data, "ssh", "soft_serve_host_ed25519"),
114 ClientKeyPath: filepath.Join(data, "ssh", "soft_serve_client_ed25519"),
115 },
116 Git: config.GitConfig{
117 ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
118 IdleTimeout: 3,
119 MaxConnections: 32,
120 },
121 HTTP: config.HTTPConfig{
122 ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
123 PublicURL: fmt.Sprintf("http://localhost:%d", test.RandomPort()),
124 },
125 Stats: config.StatsConfig{
126 ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
127 },
128 Log: config.LogConfig{
129 Format: "text",
130 TimeFormat: time.DateTime,
131 },
132 }
133 ctx := config.WithContext(context.Background(), &cfg)
134
135 // prevent race condition in lipgloss...
136 // this will probably be autofixed when we start using the colors
137 // from the ssh session instead of the server.
138 // XXX: take another look at this soon
139 lock.Lock()
140 srv, err := server.NewServer(ctx)
141 if err != nil {
142 return err
143 }
144 lock.Unlock()
145
146 go func() {
147 if err := srv.Start(); err != nil {
148 e.T().Fatal(err)
149 }
150 }()
151
152 e.Defer(func() {
153 ctx, cancel := context.WithTimeout(context.Background(), time.Second)
154 defer cancel()
155 if err := srv.Shutdown(ctx); err != nil {
156 e.T().Fatal(err)
157 }
158 })
159
160 // wait until the server is up
161 for {
162 conn, _ := net.DialTimeout(
163 "tcp",
164 net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
165 time.Second,
166 )
167 if conn != nil {
168 conn.Close()
169 break
170 }
171 }
172
173 return nil
174 },
175 })
176}