1package testscript
2
3import (
4 "context"
5 "flag"
6 "fmt"
7 "net"
8 "os"
9 "os/exec"
10 "path/filepath"
11 "runtime"
12 "strings"
13 "sync"
14 "testing"
15 "time"
16
17 "github.com/charmbracelet/keygen"
18 "github.com/charmbracelet/soft-serve/server"
19 "github.com/charmbracelet/soft-serve/server/config"
20 "github.com/charmbracelet/soft-serve/server/test"
21 "github.com/rogpeppe/go-internal/testscript"
22)
23
// update holds the value of the -update command-line flag. It is handed to
// testscript as Params.UpdateScripts so that script files can be rewritten
// with actual output instead of failing the comparison.
var (
	update = flag.Bool("update", false, "update script files")
)
25
26func TestScript(t *testing.T) {
27 flag.Parse()
28 var lock sync.Mutex
29
30 t.Setenv("SOFT_SERVE_TEST_NO_HOOKS", "1")
31
32 mkkey := func(name string) (string, *keygen.SSHKeyPair) {
33 path := filepath.Join(t.TempDir(), name)
34 pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
35 if err != nil {
36 t.Fatal(err)
37 }
38 return path, pair
39 }
40
41 key, admin1 := mkkey("admin1")
42 _, admin2 := mkkey("admin2")
43 _, user1 := mkkey("user1")
44
45 sshArgs := []string{
46 "-F", "/dev/null",
47 "-o", "StrictHostKeyChecking=no",
48 "-o", "UserKnownHostsFile=/dev/null",
49 "-o", "IdentityAgent=none",
50 "-o", "IdentitiesOnly=yes",
51 "-o", "ServerAliveInterval=60",
52 "-i", key,
53 }
54
55 check := func(ts *testscript.TestScript, err error, neg bool) {
56 if neg && err == nil {
57 ts.Fatalf("expected error, got nil")
58 }
59 if !neg {
60 ts.Check(err)
61 }
62 }
63
64 testscript.Run(t, testscript.Params{
65 Dir: "./testdata/",
66 UpdateScripts: *update,
67 Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
68 "soft": func(ts *testscript.TestScript, neg bool, args []string) {
69 args = append(
70 sshArgs,
71 append([]string{
72 "-p", ts.Getenv("SSH_PORT"),
73 "localhost",
74 "--",
75 }, args...)...,
76 )
77 if runtime.GOOS == "windows" {
78 cmd := exec.Command("ssh.exe", args...)
79 out, err := cmd.CombinedOutput()
80 ts.Logf("WINDOWS RAN %v:\n\tOUTPUT: %s\n\tERROR: %v", cmd.Args, string(out), err)
81 check(ts, err, neg)
82 } else {
83 check(ts, ts.Exec("ssh", args...), neg)
84 }
85 },
86 "git": func(ts *testscript.TestScript, neg bool, args []string) {
87 ts.Setenv(
88 "GIT_SSH_COMMAND",
89 strings.Join(append([]string{"ssh"}, sshArgs...), " "),
90 )
91 args = append([]string{
92 "-c", "user.email=john@example.com",
93 "-c", "user.name=John Doe",
94 }, args...)
95 check(ts, ts.Exec("git", args...), neg)
96 },
97 "mkreadme": func(ts *testscript.TestScript, neg bool, args []string) {
98 if len(args) != 1 {
99 ts.Fatalf("must have exactly 1 arg, the filename, got %d", len(args))
100 }
101 check(ts, os.WriteFile(ts.MkAbs(args[0]), []byte("# example\ntest project"), 0o644), neg)
102 },
103 },
104 Setup: func(e *testscript.Env) error {
105 sshPort := test.RandomPort()
106 e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
107 e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
108 e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
109 e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
110 data := t.TempDir()
111 cfg := config.Config{
112 Name: "Test Soft Serve",
113 DataPath: data,
114 InitialAdminKeys: []string{admin1.AuthorizedKey()},
115 SSH: config.SSHConfig{
116 ListenAddr: fmt.Sprintf("localhost:%d", sshPort),
117 PublicURL: fmt.Sprintf("ssh://localhost:%d", sshPort),
118 KeyPath: filepath.Join(data, "ssh", "soft_serve_host_ed25519"),
119 ClientKeyPath: filepath.Join(data, "ssh", "soft_serve_client_ed25519"),
120 },
121 Git: config.GitConfig{
122 ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
123 IdleTimeout: 3,
124 MaxConnections: 32,
125 },
126 HTTP: config.HTTPConfig{
127 ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
128 PublicURL: fmt.Sprintf("http://localhost:%d", test.RandomPort()),
129 },
130 Stats: config.StatsConfig{
131 ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
132 },
133 Log: config.LogConfig{
134 Format: "text",
135 TimeFormat: time.DateTime,
136 },
137 }
138 ctx := config.WithContext(context.Background(), &cfg)
139
140 // prevent race condition in lipgloss...
141 // this will probably be autofixed when we start using the colors
142 // from the ssh session instead of the server.
143 // XXX: take another look at this soon
144 lock.Lock()
145 srv, err := server.NewServer(ctx)
146 if err != nil {
147 return err
148 }
149 lock.Unlock()
150
151 go func() {
152 if err := srv.Start(); err != nil {
153 e.T().Fatal(err)
154 }
155 }()
156
157 e.Defer(func() {
158 ctx, cancel := context.WithTimeout(context.Background(), time.Second)
159 defer cancel()
160 if err := srv.Shutdown(ctx); err != nil {
161 e.T().Fatal(err)
162 }
163 })
164
165 // wait until the server is up
166 for {
167 conn, _ := net.DialTimeout(
168 "tcp",
169 net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
170 time.Second,
171 )
172 if conn != nil {
173 conn.Close()
174 break
175 }
176 }
177
178 return nil
179 },
180 })
181}