script_test.go

  1package testscript
  2
  3import (
  4	"bytes"
  5	"context"
  6	"flag"
  7	"fmt"
  8	"net"
  9	"os"
 10	"path/filepath"
 11	"runtime"
 12	"strings"
 13	"sync"
 14	"testing"
 15	"time"
 16
 17	"github.com/charmbracelet/keygen"
 18	"github.com/charmbracelet/soft-serve/server"
 19	"github.com/charmbracelet/soft-serve/server/config"
 20	"github.com/charmbracelet/soft-serve/server/test"
 21	"github.com/rogpeppe/go-internal/testscript"
 22	"golang.org/x/crypto/ssh"
 23)
 24
// update is the -update test flag. TestScript passes its value to
// testscript.Params.UpdateScripts.
var update = flag.Bool("update", false, "update script files")
 26
 27func TestScript(t *testing.T) {
 28	flag.Parse()
 29	var lock sync.Mutex
 30
 31	mkkey := func(name string) (string, *keygen.SSHKeyPair) {
 32		path := filepath.Join(t.TempDir(), name)
 33		pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
 34		if err != nil {
 35			t.Fatal(err)
 36		}
 37		return path, pair
 38	}
 39
 40	key, admin1 := mkkey("admin1")
 41	_, admin2 := mkkey("admin2")
 42	_, user1 := mkkey("user1")
 43
 44	testscript.Run(t, testscript.Params{
 45		Dir:           "./testdata/",
 46		UpdateScripts: *update,
 47		Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
 48			"soft":     cmdSoft(admin1.Signer()),
 49			"git":      cmdGit(key),
 50			"mkreadme": cmdMkReadme,
 51			"dos2unix": cmdDos2Unix,
 52		},
 53		Setup: func(e *testscript.Env) error {
 54			sshPort := test.RandomPort()
 55			e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
 56			e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
 57			e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
 58			e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
 59			data := t.TempDir()
 60			cfg := config.Config{
 61				Name:             "Test Soft Serve",
 62				DataPath:         data,
 63				InitialAdminKeys: []string{admin1.AuthorizedKey()},
 64				SSH: config.SSHConfig{
 65					ListenAddr:    fmt.Sprintf("localhost:%d", sshPort),
 66					PublicURL:     fmt.Sprintf("ssh://localhost:%d", sshPort),
 67					KeyPath:       filepath.Join(data, "ssh", "soft_serve_host_ed25519"),
 68					ClientKeyPath: filepath.Join(data, "ssh", "soft_serve_client_ed25519"),
 69				},
 70				Git: config.GitConfig{
 71					ListenAddr:     fmt.Sprintf("localhost:%d", test.RandomPort()),
 72					IdleTimeout:    3,
 73					MaxConnections: 32,
 74				},
 75				HTTP: config.HTTPConfig{
 76					ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
 77					PublicURL:  fmt.Sprintf("http://localhost:%d", test.RandomPort()),
 78				},
 79				Stats: config.StatsConfig{
 80					ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
 81				},
 82				Log: config.LogConfig{
 83					Format:     "text",
 84					TimeFormat: time.DateTime,
 85				},
 86			}
 87			ctx := config.WithContext(context.Background(), &cfg)
 88
 89			// prevent race condition in lipgloss...
 90			// this will probably be autofixed when we start using the colors
 91			// from the ssh session instead of the server.
 92			// XXX: take another look at this soon
 93			lock.Lock()
 94			srv, err := server.NewServer(ctx)
 95			if err != nil {
 96				return err
 97			}
 98			lock.Unlock()
 99
100			go func() {
101				if err := srv.Start(); err != nil {
102					e.T().Fatal(err)
103				}
104			}()
105
106			e.Defer(func() {
107				ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
108				defer cancel()
109				if err := srv.Shutdown(ctx); err != nil {
110					e.T().Fatal(err)
111				}
112			})
113
114			// wait until the server is up
115			for {
116				conn, _ := net.DialTimeout(
117					"tcp",
118					net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
119					time.Second,
120				)
121				if conn != nil {
122					conn.Close()
123					break
124				}
125			}
126
127			return nil
128		},
129	})
130}
131
132func cmdSoft(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
133	return func(ts *testscript.TestScript, neg bool, args []string) {
134		cli, err := ssh.Dial(
135			"tcp",
136			net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
137			&ssh.ClientConfig{
138				User:            "admin",
139				Auth:            []ssh.AuthMethod{ssh.PublicKeys(key)},
140				HostKeyCallback: ssh.InsecureIgnoreHostKey(),
141			},
142		)
143		ts.Check(err)
144		defer cli.Close()
145
146		sess, err := cli.NewSession()
147		ts.Check(err)
148		defer sess.Close()
149
150		sess.Stdout = ts.Stdout()
151		sess.Stderr = ts.Stderr()
152
153		check(ts, sess.Run(strings.Join(args, " ")), neg)
154	}
155}
156
157// P.S. Windows sucks!
158func cmdDos2Unix(ts *testscript.TestScript, neg bool, args []string) {
159	if neg {
160		ts.Fatalf("unsupported: ! dos2unix")
161	}
162	if len(args) < 1 {
163		ts.Fatalf("usage: dos2unix paths...")
164	}
165	for _, arg := range args {
166		filename := ts.MkAbs(arg)
167		data, err := os.ReadFile(filename)
168		if err != nil {
169			ts.Fatalf("%s: %v", filename, err)
170		}
171
172		// Replace all '\r\n' with '\n'.
173		data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
174
175		if err := os.WriteFile(filename, data, 0o644); err != nil {
176			ts.Fatalf("%s: %v", filename, err)
177		}
178	}
179}
180
181func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
182	return func(ts *testscript.TestScript, neg bool, args []string) {
183		sshArgs := []string{
184			"-o", "StrictHostKeyChecking=no",
185			"-o", "IdentityAgent=none",
186			"-o", "IdentitiesOnly=yes",
187			"-o", "ServerAliveInterval=60",
188			// Escape the key path for Windows.
189			"-i", strings.ReplaceAll(key, `\`, `\\`),
190		}
191		// Windows null device
192		// https://stackoverflow.com/a/36746090/10913628
193		if runtime.GOOS == "windows" {
194			sshArgs = append(sshArgs, []string{
195				"-F", `$nul`,
196				"-o", `UserKnownHostsFile=$nul`,
197			}...)
198		} else {
199			sshArgs = append(sshArgs, []string{
200				"-F", "/dev/null",
201				"-o", "UserKnownHostsFile=/dev/null",
202			}...)
203		}
204		ts.Setenv(
205			"GIT_SSH_COMMAND",
206			strings.Join(append([]string{"ssh"}, sshArgs...), " "),
207		)
208		args = append([]string{
209			"-c", "user.email=john@example.com",
210			"-c", "user.name=John Doe",
211		}, args...)
212		check(ts, ts.Exec("git", args...), neg)
213	}
214}
215
216func cmdMkReadme(ts *testscript.TestScript, neg bool, args []string) {
217	if len(args) != 1 {
218		ts.Fatalf("usage: mkreadme path")
219	}
220	content := []byte("# example\ntest project")
221	check(ts, os.WriteFile(ts.MkAbs(args[0]), content, 0o644), neg)
222}
223
224func check(ts *testscript.TestScript, err error, neg bool) {
225	if neg && err == nil {
226		ts.Fatalf("expected error, got nil")
227	}
228	if !neg {
229		ts.Check(err)
230	}
231}