1package testscript
2
3import (
4 "context"
5 "flag"
6 "fmt"
7 "net"
8 "os"
9 "path/filepath"
10 "strings"
11 "sync"
12 "testing"
13 "time"
14
15 "github.com/charmbracelet/keygen"
16 "github.com/charmbracelet/soft-serve/server"
17 "github.com/charmbracelet/soft-serve/server/config"
18 "github.com/charmbracelet/soft-serve/server/test"
19 "github.com/rogpeppe/go-internal/testscript"
20)
21
// update is the -update command-line flag. Its value is handed to
// testscript as Params.UpdateScripts, which lets testscript rewrite the
// expected content inside the script files instead of failing on mismatch.
var (
	update = flag.Bool("update", false, "update script files")
)
23
24func TestScript(t *testing.T) {
25 flag.Parse()
26 var lock sync.Mutex
27
28 t.Setenv("SOFT_SERVE_TEST_NO_HOOKS", "1")
29
30 mkkey := func(name string) (string, *keygen.SSHKeyPair) {
31 path := filepath.Join(t.TempDir(), name)
32 pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
33 if err != nil {
34 t.Fatal(err)
35 }
36 return path, pair
37 }
38
39 key, admin1 := mkkey("admin1")
40 _, admin2 := mkkey("admin2")
41 _, user1 := mkkey("user1")
42
43 sshArgs := []string{
44 "-F", "/dev/null",
45 "-o", "StrictHostKeyChecking=no",
46 "-o", "UserKnownHostsFile=/dev/null",
47 "-o", "IdentityAgent=none",
48 "-o", "IdentitiesOnly=yes",
49 "-o", "ServerAliveInterval=60",
50 "-i", key,
51 }
52
53 check := func(ts *testscript.TestScript, err error, neg bool) {
54 if neg && err == nil {
55 ts.Fatalf("expected error, got nil")
56 }
57 if !neg {
58 ts.Check(err)
59 }
60 }
61
62 testscript.Run(t, testscript.Params{
63 Dir: "./testdata/",
64 UpdateScripts: *update,
65 Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
66 "soft": func(ts *testscript.TestScript, neg bool, args []string) {
67 args = append(
68 sshArgs,
69 append([]string{
70 "-p", ts.Getenv("SSH_PORT"),
71 "localhost",
72 "--",
73 }, args...)...,
74 )
75 check(ts, ts.Exec("ssh", args...), neg)
76 },
77 "git": func(ts *testscript.TestScript, neg bool, args []string) {
78 ts.Setenv(
79 "GIT_SSH_COMMAND",
80 strings.Join(append([]string{"ssh"}, sshArgs...), " "),
81 )
82 args = append([]string{
83 "-c", "user.email=john@example.com",
84 "-c", "user.name=John Doe",
85 }, args...)
86 check(ts, ts.Exec("git", args...), neg)
87 },
88 "mkreadme": func(ts *testscript.TestScript, neg bool, args []string) {
89 if len(args) != 1 {
90 ts.Fatalf("must have exactly 1 arg, the filename, got %d", len(args))
91 }
92 check(ts, os.WriteFile(ts.MkAbs(args[0]), []byte("# example\ntest project"), 0o644), neg)
93 },
94 },
95 Setup: func(e *testscript.Env) error {
96 sshPort := test.RandomPort()
97 e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
98 e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
99 e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
100 e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
101 data := t.TempDir()
102 cfg := config.Config{
103 Name: "Test Soft Serve",
104 DataPath: data,
105 InitialAdminKeys: []string{admin1.AuthorizedKey()},
106 SSH: config.SSHConfig{
107 ListenAddr: fmt.Sprintf("localhost:%d", sshPort),
108 PublicURL: fmt.Sprintf("ssh://localhost:%d", sshPort),
109 KeyPath: filepath.Join(data, "ssh", "soft_serve_host_ed25519"),
110 ClientKeyPath: filepath.Join(data, "ssh", "soft_serve_client_ed25519"),
111 },
112 Git: config.GitConfig{
113 ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
114 IdleTimeout: 3,
115 MaxConnections: 32,
116 },
117 HTTP: config.HTTPConfig{
118 ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
119 PublicURL: fmt.Sprintf("http://localhost:%d", test.RandomPort()),
120 },
121 Stats: config.StatsConfig{
122 ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
123 },
124 Log: config.LogConfig{
125 Format: "text",
126 TimeFormat: time.DateTime,
127 },
128 }
129 ctx := config.WithContext(context.Background(), &cfg)
130
131 // prevent race condition in lipgloss...
132 // this will probably be autofixed when we start using the colors
133 // from the ssh session instead of the server.
134 // XXX: take another look at this soon
135 lock.Lock()
136 srv, err := server.NewServer(ctx)
137 if err != nil {
138 return err
139 }
140 lock.Unlock()
141
142 go func() {
143 if err := srv.Start(); err != nil {
144 e.T().Fatal(err)
145 }
146 }()
147
148 e.Defer(func() {
149 ctx, cancel := context.WithTimeout(context.Background(), time.Second)
150 defer cancel()
151 if err := srv.Shutdown(ctx); err != nil {
152 e.T().Fatal(err)
153 }
154 })
155
156 // wait until the server is up
157 for {
158 conn, _ := net.DialTimeout(
159 "tcp",
160 net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
161 time.Second,
162 )
163 if conn != nil {
164 conn.Close()
165 break
166 }
167 }
168
169 return nil
170 },
171 })
172}