1package testscript
2
3import (
4 "context"
5 "flag"
6 "fmt"
7 "net"
8 "os"
9 "os/exec"
10 "path/filepath"
11 "runtime"
12 "strings"
13 "sync"
14 "testing"
15 "time"
16
17 "github.com/charmbracelet/keygen"
18 "github.com/charmbracelet/soft-serve/server"
19 "github.com/charmbracelet/soft-serve/server/config"
20 "github.com/charmbracelet/soft-serve/server/test"
21 "github.com/rogpeppe/go-internal/testscript"
22)
23
var (
	// update is the -update command-line flag (default false). TestScript
	// passes its value to testscript.Params.UpdateScripts so expected output
	// embedded in the script files can be refreshed instead of failing.
	update = flag.Bool("update", false, "update script files")
)
25
26func TestScript(t *testing.T) {
27 flag.Parse()
28 var lock sync.Mutex
29
30 t.Setenv("SOFT_SERVE_TEST_NO_HOOKS", "1")
31
32 mkkey := func(name string) (string, *keygen.SSHKeyPair) {
33 path := filepath.Join(t.TempDir(), name)
34 pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
35 if err != nil {
36 t.Fatal(err)
37 }
38 return path, pair
39 }
40
41 key, admin1 := mkkey("admin1")
42 _, admin2 := mkkey("admin2")
43 _, user1 := mkkey("user1")
44
45 sshArgs := []string{
46 "-F", "/dev/null",
47 "-o", "StrictHostKeyChecking=no",
48 "-o", "UserKnownHostsFile=/dev/null",
49 "-o", "IdentityAgent=none",
50 "-o", "IdentitiesOnly=yes",
51 "-o", "ServerAliveInterval=60",
52 "-i", key,
53 }
54
55 check := func(ts *testscript.TestScript, err error, neg bool) {
56 if neg && err == nil {
57 ts.Fatalf("expected error, got nil")
58 }
59 if !neg {
60 ts.Check(err)
61 }
62 }
63
64 testscript.Run(t, testscript.Params{
65 Dir: "testdata/script",
66 UpdateScripts: *update,
67 Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
68 "soft": func(ts *testscript.TestScript, neg bool, args []string) {
69 args = append(
70 sshArgs,
71 append([]string{
72 "-p", ts.Getenv("SSH_PORT"),
73 "localhost",
74 "--",
75 }, args...)...,
76 )
77 if runtime.GOOS == "windows" {
78 cmd := exec.Command("ssh", args...)
79 out, err := cmd.CombinedOutput()
80 ts.Logf("RUNNING %v: output: %s error: %v", cmd.Args, string(out), err)
81 }
82 check(ts, ts.Exec("ssh", args...), neg)
83 },
84 "git": func(ts *testscript.TestScript, neg bool, args []string) {
85 ts.Setenv(
86 "GIT_SSH_COMMAND",
87 strings.Join(append([]string{"ssh"}, sshArgs...), " "),
88 )
89 args = append([]string{
90 "-c", "user.email=john@example.com",
91 "-c", "user.name=John Doe",
92 }, args...)
93 check(ts, ts.Exec("git", args...), neg)
94 },
95 "mkreadme": func(ts *testscript.TestScript, neg bool, args []string) {
96 if len(args) != 1 {
97 ts.Fatalf("must have exactly 1 arg, the filename, got %d", len(args))
98 }
99 check(ts, os.WriteFile(ts.MkAbs(args[0]), []byte("# example\ntest project"), 0o644), neg)
100 },
101 },
102 Setup: func(e *testscript.Env) error {
103 sshPort := test.RandomPort()
104 e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
105 e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
106 e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
107 e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
108 data := t.TempDir()
109 cfg := config.Config{
110 Name: "Test Soft Serve",
111 DataPath: data,
112 InitialAdminKeys: []string{admin1.AuthorizedKey()},
113 SSH: config.SSHConfig{
114 ListenAddr: fmt.Sprintf("localhost:%d", sshPort),
115 PublicURL: fmt.Sprintf("ssh://localhost:%d", sshPort),
116 KeyPath: filepath.Join(data, "ssh", "soft_serve_host_ed25519"),
117 ClientKeyPath: filepath.Join(data, "ssh", "soft_serve_client_ed25519"),
118 },
119 Git: config.GitConfig{
120 ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
121 IdleTimeout: 3,
122 MaxConnections: 32,
123 },
124 HTTP: config.HTTPConfig{
125 ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
126 PublicURL: fmt.Sprintf("http://localhost:%d", test.RandomPort()),
127 },
128 Stats: config.StatsConfig{
129 ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
130 },
131 Log: config.LogConfig{
132 Format: "text",
133 TimeFormat: time.DateTime,
134 },
135 }
136 ctx := config.WithContext(context.Background(), &cfg)
137
138 // prevent race condition in lipgloss...
139 // this will probably be autofixed when we start using the colors
140 // from the ssh session instead of the server.
141 // XXX: take another look at this soon
142 lock.Lock()
143 srv, err := server.NewServer(ctx)
144 if err != nil {
145 return err
146 }
147 lock.Unlock()
148
149 go func() {
150 if err := srv.Start(); err != nil {
151 e.T().Fatal(err)
152 }
153 }()
154
155 e.Defer(func() {
156 ctx, cancel := context.WithTimeout(context.Background(), time.Second)
157 defer cancel()
158 if err := srv.Shutdown(ctx); err != nil {
159 e.T().Fatal(err)
160 }
161 })
162
163 // wait until the server is up
164 for {
165 conn, _ := net.DialTimeout(
166 "tcp",
167 net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
168 time.Second,
169 )
170 if conn != nil {
171 conn.Close()
172 break
173 }
174 }
175
176 return nil
177 },
178 })
179}