script_test.go

  1package testscript
  2
  3import (
  4	"bytes"
  5	"context"
  6	"flag"
  7	"fmt"
  8	"net"
  9	"os"
 10	"path/filepath"
 11	"strings"
 12	"sync"
 13	"testing"
 14	"time"
 15
 16	"github.com/charmbracelet/keygen"
 17	"github.com/charmbracelet/soft-serve/server"
 18	"github.com/charmbracelet/soft-serve/server/config"
 19	"github.com/charmbracelet/soft-serve/server/test"
 20	"github.com/rogpeppe/go-internal/testscript"
 21	"golang.org/x/crypto/ssh"
 22)
 23
 24var update = flag.Bool("update", false, "update script files")
 25
 26func TestScript(t *testing.T) {
 27	flag.Parse()
 28	var lock sync.Mutex
 29
 30	mkkey := func(name string) (string, *keygen.SSHKeyPair) {
 31		path := filepath.Join(t.TempDir(), name)
 32		pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
 33		if err != nil {
 34			t.Fatal(err)
 35		}
 36		return path, pair
 37	}
 38
 39	key, admin1 := mkkey("admin1")
 40	_, admin2 := mkkey("admin2")
 41	_, user1 := mkkey("user1")
 42
 43	testscript.Run(t, testscript.Params{
 44		Dir:           "./testdata/",
 45		UpdateScripts: *update,
 46		Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
 47			"soft":     cmdSoft(admin1.Signer()),
 48			"git":      cmdGit(key),
 49			"mkreadme": cmdMkReadme,
 50			"unix2dos": cmdUnix2Dos,
 51		},
 52		Setup: func(e *testscript.Env) error {
 53			sshPort := test.RandomPort()
 54			e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
 55			e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
 56			e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
 57			e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
 58			data := t.TempDir()
 59			cfg := config.Config{
 60				Name:             "Test Soft Serve",
 61				DataPath:         data,
 62				InitialAdminKeys: []string{admin1.AuthorizedKey()},
 63				SSH: config.SSHConfig{
 64					ListenAddr:    fmt.Sprintf("localhost:%d", sshPort),
 65					PublicURL:     fmt.Sprintf("ssh://localhost:%d", sshPort),
 66					KeyPath:       filepath.Join(data, "ssh", "soft_serve_host_ed25519"),
 67					ClientKeyPath: filepath.Join(data, "ssh", "soft_serve_client_ed25519"),
 68				},
 69				Git: config.GitConfig{
 70					ListenAddr:     fmt.Sprintf("localhost:%d", test.RandomPort()),
 71					IdleTimeout:    3,
 72					MaxConnections: 32,
 73				},
 74				HTTP: config.HTTPConfig{
 75					ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
 76					PublicURL:  fmt.Sprintf("http://localhost:%d", test.RandomPort()),
 77				},
 78				Stats: config.StatsConfig{
 79					ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
 80				},
 81				Log: config.LogConfig{
 82					Format:     "text",
 83					TimeFormat: time.DateTime,
 84				},
 85			}
 86			ctx := config.WithContext(context.Background(), &cfg)
 87
 88			// prevent race condition in lipgloss...
 89			// this will probably be autofixed when we start using the colors
 90			// from the ssh session instead of the server.
 91			// XXX: take another look at this soon
 92			lock.Lock()
 93			srv, err := server.NewServer(ctx)
 94			if err != nil {
 95				return err
 96			}
 97			lock.Unlock()
 98
 99			go func() {
100				if err := srv.Start(); err != nil {
101					e.T().Fatal(err)
102				}
103			}()
104
105			e.Defer(func() {
106				ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
107				defer cancel()
108				if err := srv.Shutdown(ctx); err != nil {
109					e.T().Fatal(err)
110				}
111			})
112
113			// wait until the server is up
114			for {
115				conn, _ := net.DialTimeout(
116					"tcp",
117					net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
118					time.Second,
119				)
120				if conn != nil {
121					conn.Close()
122					break
123				}
124			}
125
126			return nil
127		},
128	})
129}
130
131func cmdSoft(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
132	return func(ts *testscript.TestScript, neg bool, args []string) {
133		cli, err := ssh.Dial(
134			"tcp",
135			net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
136			&ssh.ClientConfig{
137				User:            "admin",
138				Auth:            []ssh.AuthMethod{ssh.PublicKeys(key)},
139				HostKeyCallback: ssh.InsecureIgnoreHostKey(),
140			},
141		)
142		ts.Check(err)
143		defer cli.Close()
144
145		sess, err := cli.NewSession()
146		ts.Check(err)
147		defer sess.Close()
148
149		sess.Stdout = ts.Stdout()
150		sess.Stderr = ts.Stderr()
151
152		check(ts, sess.Run(strings.Join(args, " ")), neg)
153	}
154}
155
156func cmdUnix2Dos(ts *testscript.TestScript, neg bool, args []string) {
157	if neg {
158		ts.Fatalf("unsupported: ! unix2dos")
159	}
160	if len(args) < 1 {
161		ts.Fatalf("usage: unix2dos paths...")
162	}
163	for _, arg := range args {
164		filename := ts.MkAbs(arg)
165		data, err := os.ReadFile(filename)
166		if err != nil {
167			ts.Fatalf("%s: %v", filename, err)
168		}
169
170		// First ensure we don't have any `\r\n` there already then replace all
171		// `\n` with `\r\n`.
172		// This should prevent creating `\r\r\n`.
173		data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
174		data = bytes.ReplaceAll(data, []byte{'\n'}, []byte{'\r', '\n'})
175
176		if err := os.WriteFile(filename, data, 0o644); err != nil {
177			ts.Fatalf("%s: %v", filename, err)
178		}
179	}
180}
181
182func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
183	return func(ts *testscript.TestScript, neg bool, args []string) {
184		sshArgs := []string{
185			"-F", "/dev/null",
186			"-o", "StrictHostKeyChecking=no",
187			"-o", "UserKnownHostsFile=/dev/null",
188			"-o", "IdentityAgent=none",
189			"-o", "IdentitiesOnly=yes",
190			"-o", "ServerAliveInterval=60",
191			"-i", key,
192		}
193		ts.Setenv(
194			"GIT_SSH_COMMAND",
195			strings.Join(append([]string{"ssh"}, sshArgs...), " "),
196		)
197		args = append([]string{
198			"-c", "user.email=john@example.com",
199			"-c", "user.name=John Doe",
200		}, args...)
201		check(ts, ts.Exec("git", args...), neg)
202	}
203}
204
205func cmdMkReadme(ts *testscript.TestScript, neg bool, args []string) {
206	if len(args) != 1 {
207		ts.Fatalf("usage: mkreadme path")
208	}
209	content := []byte("# example\ntest project")
210	check(ts, os.WriteFile(ts.MkAbs(args[0]), content, 0o644), neg)
211}
212
213func check(ts *testscript.TestScript, err error, neg bool) {
214	if neg && err == nil {
215		ts.Fatalf("expected error, got nil")
216	}
217	if !neg {
218		ts.Check(err)
219	}
220}