script_test.go

  1package testscript
  2
  3import (
  4	"bytes"
  5	"context"
  6	"flag"
  7	"fmt"
  8	"net"
  9	"os"
 10	"path/filepath"
 11	"strings"
 12	"sync"
 13	"testing"
 14	"time"
 15
 16	"github.com/charmbracelet/keygen"
 17	"github.com/charmbracelet/soft-serve/server"
 18	"github.com/charmbracelet/soft-serve/server/config"
 19	"github.com/charmbracelet/soft-serve/server/test"
 20	"github.com/rogpeppe/go-internal/testscript"
 21	"golang.org/x/crypto/ssh"
 22)
 23
// update is set with the -update flag and forwarded to testscript as
// Params.UpdateScripts, which lets testscript rewrite the expected-output
// sections of the scripts in ./testdata instead of failing on a mismatch.
var update = flag.Bool("update", false, "update script files")
 25
 26func TestScript(t *testing.T) {
 27	flag.Parse()
 28	var lock sync.Mutex
 29
 30	mkkey := func(name string) (string, *keygen.SSHKeyPair) {
 31		path := filepath.Join(t.TempDir(), name)
 32		pair, err := keygen.New(path, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
 33		if err != nil {
 34			t.Fatal(err)
 35		}
 36		return path, pair
 37	}
 38
 39	key, admin1 := mkkey("admin1")
 40	_, admin2 := mkkey("admin2")
 41	_, user1 := mkkey("user1")
 42
 43	testscript.Run(t, testscript.Params{
 44		Dir:           "./testdata/",
 45		UpdateScripts: *update,
 46		Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){
 47			"soft":     cmdSoft(admin1.Signer()),
 48			"git":      cmdGit(key),
 49			"mkreadme": cmdMkReadme,
 50			"unix2dos": cmdUnix2Dos,
 51		},
 52		Setup: func(e *testscript.Env) error {
 53			sshPort := test.RandomPort()
 54			e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
 55			e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
 56			e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
 57			e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
 58			data := t.TempDir()
 59			cfg := config.Config{
 60				Name:             "Test Soft Serve",
 61				DataPath:         data,
 62				InitialAdminKeys: []string{admin1.AuthorizedKey()},
 63				SSH: config.SSHConfig{
 64					ListenAddr:    fmt.Sprintf("localhost:%d", sshPort),
 65					PublicURL:     fmt.Sprintf("ssh://localhost:%d", sshPort),
 66					KeyPath:       filepath.Join(data, "ssh", "soft_serve_host_ed25519"),
 67					ClientKeyPath: filepath.Join(data, "ssh", "soft_serve_client_ed25519"),
 68				},
 69				Git: config.GitConfig{
 70					ListenAddr:     fmt.Sprintf("localhost:%d", test.RandomPort()),
 71					IdleTimeout:    3,
 72					MaxConnections: 32,
 73				},
 74				HTTP: config.HTTPConfig{
 75					ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
 76					PublicURL:  fmt.Sprintf("http://localhost:%d", test.RandomPort()),
 77				},
 78				Stats: config.StatsConfig{
 79					ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
 80				},
 81				Log: config.LogConfig{
 82					Format:     "text",
 83					TimeFormat: time.DateTime,
 84				},
 85			}
 86			ctx := config.WithContext(context.Background(), &cfg)
 87
 88			// prevent race condition in lipgloss...
 89			// this will probably be autofixed when we start using the colors
 90			// from the ssh session instead of the server.
 91			// XXX: take another look at this soon
 92			lock.Lock()
 93			srv, err := server.NewServer(ctx)
 94			if err != nil {
 95				return err
 96			}
 97			lock.Unlock()
 98
 99			go func() {
100				if err := srv.Start(); err != nil {
101					e.T().Fatal(err)
102				}
103			}()
104
105			e.Defer(func() {
106				ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
107				defer cancel()
108				if err := srv.Shutdown(ctx); err != nil {
109					e.T().Fatal(err)
110				}
111			})
112
113			// wait until the server is up
114			for {
115				conn, _ := net.DialTimeout(
116					"tcp",
117					net.JoinHostPort("localhost", fmt.Sprintf("%d", sshPort)),
118					time.Second,
119				)
120				if conn != nil {
121					conn.Close()
122					break
123				}
124			}
125
126			return nil
127		},
128	})
129}
130
131func cmdSoft(key ssh.Signer) func(ts *testscript.TestScript, neg bool, args []string) {
132	return func(ts *testscript.TestScript, neg bool, args []string) {
133		cli, err := ssh.Dial(
134			"tcp",
135			net.JoinHostPort("localhost", ts.Getenv("SSH_PORT")),
136			&ssh.ClientConfig{
137				User:            "admin",
138				Auth:            []ssh.AuthMethod{ssh.PublicKeys(key)},
139				HostKeyCallback: ssh.InsecureIgnoreHostKey(),
140			},
141		)
142		ts.Check(err)
143		defer cli.Close()
144
145		sess, err := cli.NewSession()
146		ts.Check(err)
147		defer sess.Close()
148
149		sess.Stdout = ts.Stdout()
150		sess.Stderr = ts.Stderr()
151
152		check(ts, sess.Run(strings.Join(args, " ")), neg)
153	}
154}
155
156func cmdUnix2Dos(ts *testscript.TestScript, neg bool, args []string) {
157	// now this should not be needed indeed
158	return
159	if neg {
160		ts.Fatalf("unsupported: ! unix2dos")
161	}
162	if len(args) < 1 {
163		ts.Fatalf("usage: unix2dos paths...")
164	}
165	for _, arg := range args {
166		filename := ts.MkAbs(arg)
167		data, err := os.ReadFile(filename)
168		if err != nil {
169			ts.Fatalf("%s: %v", filename, err)
170		}
171
172		// First ensure we don't have any `\r\n` there already then replace all
173		// `\n` with `\r\n`.
174		// This should prevent creating `\r\r\n`.
175		data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'})
176		data = bytes.ReplaceAll(data, []byte{'\n'}, []byte{'\r', '\n'})
177
178		if err := os.WriteFile(filename, data, 0o644); err != nil {
179			ts.Fatalf("%s: %v", filename, err)
180		}
181	}
182}
183
184func cmdGit(key string) func(ts *testscript.TestScript, neg bool, args []string) {
185	return func(ts *testscript.TestScript, neg bool, args []string) {
186		sshArgs := []string{
187			"-F", "/dev/null",
188			"-o", "StrictHostKeyChecking=no",
189			"-o", "UserKnownHostsFile=/dev/null",
190			"-o", "IdentityAgent=none",
191			"-o", "IdentitiesOnly=yes",
192			"-o", "ServerAliveInterval=60",
193			"-i", key,
194		}
195		ts.Setenv(
196			"GIT_SSH_COMMAND",
197			strings.Join(append([]string{"ssh"}, sshArgs...), " "),
198		)
199		args = append([]string{
200			"-c", "user.email=john@example.com",
201			"-c", "user.name=John Doe",
202		}, args...)
203		check(ts, ts.Exec("git", args...), neg)
204	}
205}
206
207func cmdMkReadme(ts *testscript.TestScript, neg bool, args []string) {
208	if len(args) != 1 {
209		ts.Fatalf("usage: mkreadme path")
210	}
211	content := []byte("# example\ntest project")
212	check(ts, os.WriteFile(ts.MkAbs(args[0]), content, 0o644), neg)
213}
214
215func check(ts *testscript.TestScript, err error, neg bool) {
216	if neg && err == nil {
217		ts.Fatalf("expected error, got nil")
218	}
219	if !neg {
220		ts.Check(err)
221	}
222}